diff --git a/.cargo/config.in b/.cargo/config.in index 0c77993e309f..296253d3fd7c 100644 --- a/.cargo/config.in +++ b/.cargo/config.in @@ -15,7 +15,7 @@ rev = "6a866fdad2ca880df9b87fcbc9921abac1e91914" [source."https://github.com/mozilla/neqo"] git = "https://github.com/mozilla/neqo" replace-with = "vendored-sources" -tag = "v0.1.13" +tag = "v0.1.12" [source."https://github.com/kvark/spirv_cross"] branch = "wgpu" diff --git a/Cargo.lock b/Cargo.lock index e106842c801f..7dafda6b8e62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2571,8 +2571,8 @@ checksum = "a2983372caf4480544083767bf2d27defafe32af49ab4df3a0b7fc90793a3664" [[package]] name = "neqo-common" -version = "0.1.13" -source = "git+https://github.com/mozilla/neqo?tag=v0.1.13#1bc70df50cb3e22ec6820e2a746a506317977604" +version = "0.1.12" +source = "git+https://github.com/mozilla/neqo?tag=v0.1.12#ccb7c07326230c4b1dc68d98a15ce364a4718c15" dependencies = [ "env_logger", "lazy_static", @@ -2582,8 +2582,8 @@ dependencies = [ [[package]] name = "neqo-crypto" -version = "0.1.13" -source = "git+https://github.com/mozilla/neqo?tag=v0.1.13#1bc70df50cb3e22ec6820e2a746a506317977604" +version = "0.1.12" +source = "git+https://github.com/mozilla/neqo?tag=v0.1.12#ccb7c07326230c4b1dc68d98a15ce364a4718c15" dependencies = [ "bindgen", "log", @@ -2595,8 +2595,8 @@ dependencies = [ [[package]] name = "neqo-http3" -version = "0.1.13" -source = "git+https://github.com/mozilla/neqo?tag=v0.1.13#1bc70df50cb3e22ec6820e2a746a506317977604" +version = "0.1.12" +source = "git+https://github.com/mozilla/neqo?tag=v0.1.12#ccb7c07326230c4b1dc68d98a15ce364a4718c15" dependencies = [ "log", "neqo-common", @@ -2609,8 +2609,8 @@ dependencies = [ [[package]] name = "neqo-qpack" -version = "0.1.13" -source = "git+https://github.com/mozilla/neqo?tag=v0.1.13#1bc70df50cb3e22ec6820e2a746a506317977604" +version = "0.1.12" +source = "git+https://github.com/mozilla/neqo?tag=v0.1.12#ccb7c07326230c4b1dc68d98a15ce364a4718c15" dependencies = [ "log", "neqo-common", @@ -2620,13 +2620,14 @@ dependencies = [ [[package]] name = "neqo-transport" -version = "0.1.13" -source = "git+https://github.com/mozilla/neqo?tag=v0.1.13#1bc70df50cb3e22ec6820e2a746a506317977604" +version = "0.1.12" +source = "git+https://github.com/mozilla/neqo?tag=v0.1.12#ccb7c07326230c4b1dc68d98a15ce364a4718c15" dependencies = [ "lazy_static", "log", "neqo-common", "neqo-crypto", + "rand", "smallvec 1.2.0", ] diff --git a/netwerk/protocol/http/nsHttp.cpp b/netwerk/protocol/http/nsHttp.cpp index fd01e6b5a75f..cb0a72b3ba09 100644 --- a/netwerk/protocol/http/nsHttp.cpp +++ b/netwerk/protocol/http/nsHttp.cpp @@ -27,7 +27,7 @@ namespace mozilla { namespace net { -const nsCString kHttp3Version = NS_LITERAL_CSTRING("h3-25"); +const nsCString kHttp3Version = NS_LITERAL_CSTRING("h3-24"); // define storage for all atoms namespace nsHttp { diff --git a/netwerk/protocol/http/nsHttp.h b/netwerk/protocol/http/nsHttp.h index d2bdcded16c3..93a34bdc16bf 100644 --- a/netwerk/protocol/http/nsHttp.h +++ b/netwerk/protocol/http/nsHttp.h @@ -54,7 +54,7 @@ enum class SpdyVersion { }; extern const nsCString kHttp3Version; -const char kHttp3VersionHEX[] = "ff00000019"; // this is draft 25. +const char kHttp3VersionHEX[] = "ff00000018"; // this is draft 24. //----------------------------------------------------------------------------- // http connection capabilities diff --git a/netwerk/socket/neqo_glue/Cargo.toml b/netwerk/socket/neqo_glue/Cargo.toml index fb8eb088e214..29e28b38c0fa 100644 --- a/netwerk/socket/neqo_glue/Cargo.toml +++ b/netwerk/socket/neqo_glue/Cargo.toml @@ -8,16 +8,16 @@ edition = "2018" name = "neqo_glue" [dependencies] -neqo-http3 = { tag = "v0.1.13", git = "https://github.com/mozilla/neqo" } -neqo-transport = { tag = "v0.1.13", git = "https://github.com/mozilla/neqo" } -neqo-common = { tag = "v0.1.13", git = "https://github.com/mozilla/neqo" } +neqo-http3 = { tag = "v0.1.12", git = "https://github.com/mozilla/neqo" } +neqo-transport = { tag = "v0.1.12", git = "https://github.com/mozilla/neqo" } +neqo-common = { tag = "v0.1.12", git = "https://github.com/mozilla/neqo" } nserror = { path = "../../../xpcom/rust/nserror" } nsstring = { path = "../../../xpcom/rust/nsstring" } xpcom = { path = "../../../xpcom/rust/xpcom" } thin-vec = { version = "0.1.0", features = ["gecko-ffi"] } [dependencies.neqo-crypto] -tag = "v0.1.13" +tag = "v0.1.12" git = "https://github.com/mozilla/neqo" default-features = false features = ["gecko"] diff --git a/third_party/rust/neqo-common/.cargo-checksum.json b/third_party/rust/neqo-common/.cargo-checksum.json index 35305ba359b7..5e57d8d65c8a 100644 --- a/third_party/rust/neqo-common/.cargo-checksum.json +++ b/third_party/rust/neqo-common/.cargo-checksum.json @@ -1 +1 @@ -{"files":{"Cargo.toml":"679490817e9489b056470e592edd623a607cbe613195f7fbd75eb92811836cb0","src/codec.rs":"00846df0051f32ec8b75b2f8e0344422e0693acbd4151aaec31e3ae02d6e696c","src/datagram.rs":"4beb13d5ea7927df6801fbe684dc231626c1856010eaef975d866ee66e894a45","src/incrdecoder.rs":"7b7b7fba57714a3baf0fe881010a9f5a9814bf26b9283a6d56d1c44010cbd822","src/lib.rs":"f6ee17bc45cafdccb562340a4d253a517c5366a74d07c38960aedc2554fe783c","src/log.rs":"943e4e332400d94805d60f965d1d0ae7aad180f6d5b50936d0bd9e085bbc1502","src/once.rs":"d8b2bf7a9e3ce83bdd7f29d8f73ce7ad0268c9618ae7255028fea9f90c9c9fd6","src/timer.rs":"56082a6ecb45bd31c7c677c4c1f0830e55821c860e70b5637b2015fa3be63743","tests/log.rs":"480b165b7907ec642c508b303d63005eee1427115d6973a349eaf6b2242ed18d"},"package":null} \ No newline at end of file +{"files":{"Cargo.toml":"fda2ba8d82e1c85cd700989e553a40863fe720de05b54c4c77dccc80257446c4","src/codec.rs":"6a35a0b4284b9279f1f937ba691ad7e4994a66343b92a9b69524baa71433d811","src/datagram.rs":"c0d0bfabd35f51dee6ad3edc5477e3b53f1cd24af78a756ce96cf37c78223a61","src/incrdecoder.rs":"7b7b7fba57714a3baf0fe881010a9f5a9814bf26b9283a6d56d1c44010cbd822","src/lib.rs":"f6ee17bc45cafdccb562340a4d253a517c5366a74d07c38960aedc2554fe783c","src/log.rs":"943e4e332400d94805d60f965d1d0ae7aad180f6d5b50936d0bd9e085bbc1502","src/once.rs":"d8b2bf7a9e3ce83bdd7f29d8f73ce7ad0268c9618ae7255028fea9f90c9c9fd6","src/timer.rs":"56082a6ecb45bd31c7c677c4c1f0830e55821c860e70b5637b2015fa3be63743","tests/log.rs":"23addde558449c79ac71877e1a96f24069d7e85839ec8464cb1f4e60032f29b1"},"package":null} \ No newline at end of file diff --git a/third_party/rust/neqo-common/Cargo.toml b/third_party/rust/neqo-common/Cargo.toml index d21caa34b258..ece3508757c1 100644 --- a/third_party/rust/neqo-common/Cargo.toml +++ b/third_party/rust/neqo-common/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "neqo-common" -version = "0.1.13" +version = "0.1.12" authors = ["Bobby Holley "] edition = "2018" license = "MIT/Apache-2.0" diff --git a/third_party/rust/neqo-common/src/codec.rs b/third_party/rust/neqo-common/src/codec.rs index ec4ea4d7e539..a9ba185090a0 100644 --- a/third_party/rust/neqo-common/src/codec.rs +++ b/third_party/rust/neqo-common/src/codec.rs @@ -29,12 +29,6 @@ impl<'a> Decoder<'a> { self.buf.len() - self.offset } - /// The number of bytes from the underlying slice that have been decoded. - #[must_use] - pub fn offset(&self) -> usize { - self.offset - } - /// Skip n bytes. Panics if `n` is too large. pub fn skip(&mut self, n: usize) { assert!(self.remaining() >= n); @@ -80,7 +74,7 @@ impl<'a> Decoder<'a> { } /// Decodes arbitrary data. - pub fn decode(&mut self, n: usize) -> Option<&'a [u8]> { + pub fn decode(&mut self, n: usize) -> Option<&[u8]> { if self.remaining() < n { return None; } @@ -120,13 +114,13 @@ impl<'a> Decoder<'a> { } /// Decodes the rest of the buffer. Infallible. - pub fn decode_remainder(&mut self) -> &'a [u8] { + pub fn decode_remainder(&mut self) -> &[u8] { let res = &self.buf[self.offset..]; self.offset = self.buf.len(); res } - fn decode_checked(&mut self, n: Option) -> Option<&'a [u8]> { + fn decode_checked(&mut self, n: Option) -> Option<&[u8]> { let len = match n { Some(l) => l, _ => return None, @@ -142,13 +136,13 @@ impl<'a> Decoder<'a> { } /// Decodes a TLS-style length-prefixed buffer. - pub fn decode_vec(&mut self, n: usize) -> Option<&'a [u8]> { + pub fn decode_vec(&mut self, n: usize) -> Option<&[u8]> { let len = self.decode_uint(n); self.decode_checked(len) } /// Decodes a QUIC varint-length-prefixed buffer. - pub fn decode_vvec(&mut self) -> Option<&'a [u8]> { + pub fn decode_vvec(&mut self) -> Option<&[u8]> { let len = self.decode_varint(); self.decode_checked(len) } @@ -339,11 +333,6 @@ impl Encoder { self.buf[start..].rotate_right(count); self } - - /// Truncate the encoder to the given size. - pub fn truncate(&mut self, len: usize) { - self.buf.truncate(len); - } } impl Debug for Encoder { diff --git a/third_party/rust/neqo-common/src/datagram.rs b/third_party/rust/neqo-common/src/datagram.rs index 03c1eb945d40..b13366ef4652 100644 --- a/third_party/rust/neqo-common/src/datagram.rs +++ b/third_party/rust/neqo-common/src/datagram.rs @@ -7,9 +7,7 @@ use std::net::SocketAddr; use std::ops::Deref; -use crate::hex; - -#[derive(PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone)] pub struct Datagram { src: SocketAddr, dst: SocketAddr, @@ -43,15 +41,3 @@ impl Deref for Datagram { &self.d } } - -impl std::fmt::Debug for Datagram { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "Datagram {:?}->{:?}: {}", - self.src, - self.dst, - hex(&self.d) - ) - } -} diff --git a/third_party/rust/neqo-common/tests/log.rs b/third_party/rust/neqo-common/tests/log.rs index 33b42d1411f5..2da23be2734e 100644 --- a/third_party/rust/neqo-common/tests/log.rs +++ b/third_party/rust/neqo-common/tests/log.rs @@ -5,7 +5,6 @@ // except according to those terms. #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::use_self)] use neqo_common::{qdebug, qerror, qinfo, qtrace, qwarn}; diff --git a/third_party/rust/neqo-crypto/.cargo-checksum.json b/third_party/rust/neqo-crypto/.cargo-checksum.json index 618f24ef87f0..da462d3bad91 100644 --- a/third_party/rust/neqo-crypto/.cargo-checksum.json +++ b/third_party/rust/neqo-crypto/.cargo-checksum.json @@ -1 +1 @@ -{"files":{"Cargo.toml":"7590c2cd1d5286f1546c7a09fc3dd07b2b0182e81bcc551a323ff4965230f36c","TODO":"ac0f1c2ebcca03f5b3c0cc56c5aedbb030a4b511e438bc07a57361c789f91e9f","bindings/bindings.toml":"00ff7348732c956b4f8829f00df2b18b3a7211f5fa2a4cea4ae40c0f859e5f50","bindings/mozpkix.hpp":"77072c8bb0f6eb6bfe8cbadc111dcd92e0c79936d13f2e501aae1e5d289a6675","bindings/nspr_err.h":"2d5205d017b536c2d838bcf9bc4ec79f96dd50e7bb9b73892328781f1ee6629d","bindings/nspr_error.h":"e41c03c77b8c22046f8618832c9569fbcc7b26d8b9bbc35eea7168f35e346889","bindings/nspr_io.h":"085b289849ef0e77f88512a27b4d9bdc28252bd4d39c6a17303204e46ef45f72","bindings/nspr_time.h":"2e637fd338a5cf0fd3fb0070a47f474a34c2a7f4447f31b6875f5a9928d0a261","bindings/nss_ciphers.h":"95ec6344a607558b3c5ba8510f463b6295f3a2fb3f538a01410531045a5f62d1","bindings/nss_init.h":"ef49045063782fb612aff459172cc6a89340f15005808608ade5320ca9974310","bindings/nss_p11.h":"0b81e64fe6db49b2ecff94edd850be111ef99ec11220e88ceb1c67be90143a78","bindings/nss_secerr.h":"713e8368bdae5159af7893cfa517dabfe5103cede051dee9c9557c850a2defc6","bindings/nss_ssl.h":"af222fb957b989e392e762fa2125c82608a0053aff4fb97e556691646c88c335","bindings/nss_sslerr.h":"24b97f092183d8486f774cdaef5030d0249221c78343570d83a4ee5b594210ae","bindings/nss_sslopt.h":"b7807eb7abdad14db6ad7bc51048a46b065a0ea65a4508c95a12ce90e59d1eea","build.rs":"eb324a3f076f0079acc0332379bc68ba6fb1232c3f9e44ef63334fe625d569c1","src/aead.rs":"2013408fbcf9e93331ae14d9d6bdd096966f125b3cf48f83e671f537e89d4e77","src/agent.rs":"5f460010eff4a604b23c456b5cff132f995b30767f0188285fdf39d7724ecf6f","src/agentio.rs":"aeb91f3e4c4cc5b8a816307747090c5df02924801511f9523f9d767fe9dd67e9","src/auth.rs":"71ac7e297a5f872d26cf67b6bbd96e4548ea38374bdd84c1094f76a5de4ed1cb","src/cert.rs":"fd3fd2bbb38754bdcee3898549feae412943c9f719032531c1ad6e61783b5394","src/constants.rs":"e756c07525bd7c2ff271e504708f903b3ede0a3ae821446bd37701055eb11f5f","src/err.rs":"04f38831ca62d29d8aadfe9daf95fd29e68ece184e6d3e00bfb9ee1d12744033","src/exp.rs":"61586662407359c1ecb8ed4987bc3c702f26ba2e203a091a51b6d6363cbd510f","src/ext.rs":"bf7b5f23caf26ab14fba3baf0823dd093e4194f759779e4cfd608478312ed58c","src/hkdf.rs":"1bb57806bbf67af74966bb2bb724de9d6b0094c6f5cddbe12d46292d58ba1f16","src/hp.rs":"0384bc676d8cc66a2cfec7be9df176f04557e4f1424c6d19d03ba5687920ac86","src/lib.rs":"49e0ad22fb5aec2e0864b907cb6d419389d53014e33c147f53198b440ec8929f","src/p11.rs":"6e94cbb594b709c3081449bf50d9961d36648b5db95fb824779bff4f45125ad2","src/prio.rs":"bc4e97049563b136cb7b39f5171e7909d56a77ed46690aaacb781eeb4a4743e0","src/replay.rs":"9bc5826cc8be6afe787f0d403b3958245efce9bfbc7b3100734e5aec3f8b9753","src/result.rs":"cef34dfcb907723e195b56501132e4560e250b327783cb5e41201da5b63e9b5c","src/secrets.rs":"531ec0de048f55108f2612d8f330bee18ffd58b3b26124ca290cc14cec8671dc","src/selfencrypt.rs":"02e963e8b9ea0802f7ee64384e5ccef3e31420e75bc1aacd02270dd504ffbdb1","src/ssl.rs":"ee0e638bd0a6ce2f01ecb6a1c1a203ac7a7ae8145b889a0d6f2015f98d65c4b4","src/time.rs":"d77f0f276385603633b2078f05ff9b4dddc8cfb84c595697689876b6996f69d2","tests/aead.rs":"cccac271087fe37d0a890e5da04984bbfacb4bc12331473dfc189e4d6ebff5f2","tests/agent.rs":"4fa8fa803266b985e9b6329e6a218fe7bd779200b8e0cfa94f5813e0ccc10995","tests/ext.rs":"f5edc1f229703f786ec31a8035465c00275223f14a3c4abe52f3c7cf2686cc03","tests/handshake.rs":"bcc687c0e1b485658847faf28a9f5dbfdb297812bed1bd2e80593d5f9e1fee36","tests/hkdf.rs":"0e4853f629050ba4d8069be52b7a441b670d1abaf6b8cd670a8215e0b88beb37","tests/hp.rs":"e6dd3cb4bceebc6fca8f270d8302ef34e14bda6c91fc4f9342ba1681be57ee03","tests/init.rs":"55df7cb95deb629f8701b55a8bcb91e797f30fb10e847a36a0a5a4e80488b002","tests/selfencrypt.rs":"60bfe8a0729cdaa6c2171146083266fa0e625a1d98b5f8735cd22b725d32398b"},"package":null} \ No newline at end of file +{"files":{"Cargo.toml":"7636458c97a2cffc541b679cb0fc53a5aa4ea5c9dc462e1a383f7c471f112d35","TODO":"ac0f1c2ebcca03f5b3c0cc56c5aedbb030a4b511e438bc07a57361c789f91e9f","bindings/bindings.toml":"0f305bda9513e7fb4b521df79912ad5ba21784377b84f4b531895619e561f356","bindings/mozpkix.hpp":"77072c8bb0f6eb6bfe8cbadc111dcd92e0c79936d13f2e501aae1e5d289a6675","bindings/nspr_err.h":"2d5205d017b536c2d838bcf9bc4ec79f96dd50e7bb9b73892328781f1ee6629d","bindings/nspr_error.h":"e41c03c77b8c22046f8618832c9569fbcc7b26d8b9bbc35eea7168f35e346889","bindings/nspr_io.h":"085b289849ef0e77f88512a27b4d9bdc28252bd4d39c6a17303204e46ef45f72","bindings/nspr_time.h":"2e637fd338a5cf0fd3fb0070a47f474a34c2a7f4447f31b6875f5a9928d0a261","bindings/nss_ciphers.h":"95ec6344a607558b3c5ba8510f463b6295f3a2fb3f538a01410531045a5f62d1","bindings/nss_init.h":"ef49045063782fb612aff459172cc6a89340f15005808608ade5320ca9974310","bindings/nss_p11.h":"0b81e64fe6db49b2ecff94edd850be111ef99ec11220e88ceb1c67be90143a78","bindings/nss_secerr.h":"713e8368bdae5159af7893cfa517dabfe5103cede051dee9c9557c850a2defc6","bindings/nss_ssl.h":"af222fb957b989e392e762fa2125c82608a0053aff4fb97e556691646c88c335","bindings/nss_sslerr.h":"24b97f092183d8486f774cdaef5030d0249221c78343570d83a4ee5b594210ae","bindings/nss_sslopt.h":"b7807eb7abdad14db6ad7bc51048a46b065a0ea65a4508c95a12ce90e59d1eea","build.rs":"363243f6fb484c081dc73ad456ec2f7577525f94113930c49f3c466784405a70","src/aead.rs":"b598dc13a6fd1e97848571ef130202abb3ad05eab95334668c06b2480387ef5b","src/agent.rs":"1a0af9b1354023c976120c1ec92f2b5e25f9427cbc61dfa9772a267a47882731","src/agentio.rs":"712ff073e1fd9a55169481502bc7d2b78e0f3b498cfa55635c8061867d511cd1","src/auth.rs":"71ac7e297a5f872d26cf67b6bbd96e4548ea38374bdd84c1094f76a5de4ed1cb","src/cert.rs":"fd3fd2bbb38754bdcee3898549feae412943c9f719032531c1ad6e61783b5394","src/constants.rs":"75dec8e3c74326f492a115a0e7a487daba32eba30bcbd64d2223333b3caa4008","src/err.rs":"04f38831ca62d29d8aadfe9daf95fd29e68ece184e6d3e00bfb9ee1d12744033","src/exp.rs":"61586662407359c1ecb8ed4987bc3c702f26ba2e203a091a51b6d6363cbd510f","src/ext.rs":"e9b251fd156b49eff221c079ce3c4095cd7d13cd52e711a6a08f9682764073a5","src/hkdf.rs":"6d44f63493f0c558a23339f88fe766f8afdb0bda3dc11a79e8a99d3c8d0b6acb","src/hp.rs":"854ce7b9d44892fbb01ac4078b84266771a9254cebfea5b94e7f4b4a7fb1b946","src/lib.rs":"9ada53450e66cdcf944a72e4b23feb06a3c3c92813e79a2ac122dea6319a6d81","src/p11.rs":"7a755e56372b4b037359cf9de9ba831bf986599e371a47d56630a2cc3dcc82da","src/prio.rs":"0e213056f6bf0c797c2cfe13c6d14dbb64a64b1218fff21cbf36fb3309b852f9","src/replay.rs":"01eae2accfefbc26719fcccd4bcb8c1ea6400ab96fbb696ecdb8f32164f931a2","src/result.rs":"d76c7bc5e99c80a5a124909ab586cdb91d894856e52f5146430da43584a6d6c1","src/secrets.rs":"09b26118995b3b2301646f5e9fa9ce25ad813e4780184eda53ab7eff8d1da5c5","src/selfencrypt.rs":"3a642f95073e329f9211468304605dd3953d9fbeef24cf64f00dd541ef6a30ee","src/ssl.rs":"d8bf4aa4869e7d161a2e862d7a628484f8736273823320fc6eb693ba86b8aef0","src/time.rs":"d77f0f276385603633b2078f05ff9b4dddc8cfb84c595697689876b6996f69d2","tests/aead.rs":"4472c50dff9e70533bfa9bc0c964011b832335066d3e3e7fe4d3240f40291b7f","tests/agent.rs":"451cf24b3f211f7b31fffea58f2a0d9d760c9b1af8dc7c47be663c099f4cfd65","tests/ext.rs":"5249e866a5a0b57fee733f99a7d9cba16ef63358de5657c9b69488d8e6e680a4","tests/handshake.rs":"a9dab4781c63b58fecfe9434275fcea53d08981ce99980113f47c10416f5ba63","tests/hkdf.rs":"afc6a7654c6222ff17f68dbba7ca16b0e13044e1107e19e9b9ca6fa2bd473bce","tests/hp.rs":"77c4998ee25ebd8ffc3e00f9bf79d03a42df3a291d0fb371b4e8ea680876cf4b","tests/init.rs":"0243ec4b6052a8ce83494b815fca3a19aed3fcb44d87e0206faeb21529d63445","tests/selfencrypt.rs":"88eec5c3421d5a9efe7fc4ddac749bed8afc1789c6cc1a7701ebf5d5f23c58ec"},"package":null} \ No newline at end of file diff --git a/third_party/rust/neqo-crypto/Cargo.toml b/third_party/rust/neqo-crypto/Cargo.toml index 993c2b65880b..1351f003fe83 100644 --- a/third_party/rust/neqo-crypto/Cargo.toml +++ b/third_party/rust/neqo-crypto/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "neqo-crypto" -version = "0.1.13" +version = "0.1.12" authors = ["Martin Thomson "] edition = "2018" build = "build.rs" diff --git a/third_party/rust/neqo-crypto/bindings/bindings.toml b/third_party/rust/neqo-crypto/bindings/bindings.toml index c93a9ab45409..c46f1c687151 100644 --- a/third_party/rust/neqo-crypto/bindings/bindings.toml +++ b/third_party/rust/neqo-crypto/bindings/bindings.toml @@ -39,7 +39,6 @@ functions = [ "SSL_ImportFD", "SSL_NamedGroupConfig", "SSL_OptionSet", - "SSL_OptionGetDefault", "SSL_PeerCertificate", "SSL_PeerCertificateChain", "SSL_PeerSignedCertTimestamps", diff --git a/third_party/rust/neqo-crypto/build.rs b/third_party/rust/neqo-crypto/build.rs index 09a077374a8d..55ccfb7168f1 100644 --- a/third_party/rust/neqo-crypto/build.rs +++ b/third_party/rust/neqo-crypto/build.rs @@ -129,10 +129,7 @@ fn get_bash() -> PathBuf { } fn build_nss(dir: PathBuf) { - let mut build_nss = vec![ - String::from("./build.sh"), - String::from("-Ddisable_tests=1"), - ]; + let mut build_nss = vec![String::from("./build.sh")]; if is_debug() { build_nss.push(String::from("--static")); } else { @@ -170,7 +167,12 @@ fn dynamic_link_both(extra_libs: &[&str]) { } } -fn static_link() { +fn static_link(nsstarget: &PathBuf) { + let lib_dir = nsstarget.join("lib"); + println!( + "cargo:rustc-link-search=native={}", + lib_dir.to_str().unwrap() + ); let mut static_libs = vec![ "certdb", "certhi", @@ -297,8 +299,9 @@ fn setup_standalone() -> Vec { "cargo:rustc-link-search=native={}", nsslibdir.to_str().unwrap() ); + if is_debug() { - static_link(); + static_link(&nsstarget); } else { dynamic_link(); } diff --git a/third_party/rust/neqo-crypto/src/aead.rs b/third_party/rust/neqo-crypto/src/aead.rs index 50ef14908601..5e9a42cd6bda 100644 --- a/third_party/rust/neqo-crypto/src/aead.rs +++ b/third_party/rust/neqo-crypto/src/aead.rs @@ -53,13 +53,7 @@ pub struct Aead { ctx: AeadContext, } -// TODO(mt) move unused_self once https://github.com/rust-lang/rust-clippy/issues/5053 is fixed -#[allow(clippy::unused_self)] impl Aead { - /// Create a new AEAD based on the indicated TLS version and cipher suite. - /// - /// # Errors - /// Returns `Error` when the supporting NSS functions fail. pub fn new(version: Version, cipher: Cipher, secret: &SymKey, prefix: &str) -> Res { let s: *mut PK11SymKey = **secret; unsafe { Self::from_raw(version, cipher, s, prefix) } @@ -98,9 +92,6 @@ impl Aead { /// /// The space provided in `output` needs to be larger than `input` by /// the value provided in `Aead::expansion`. - /// - /// # Errors - /// If the input can't be protected or any input is too large for NSS. pub fn encrypt<'a>( &self, count: u64, @@ -130,9 +121,6 @@ impl Aead { /// Note that NSS insists upon having extra space available for decryption, so /// the buffer for `output` should be the same length as `input`, even though /// the final result will be shorter. - /// - /// # Errors - /// If the input isn't authenticated or any input is too large for NSS. pub fn decrypt<'a>( &self, count: u64, diff --git a/third_party/rust/neqo-crypto/src/agent.rs b/third_party/rust/neqo-crypto/src/agent.rs index 9f376316568a..c8ca44253ea5 100644 --- a/third_party/rust/neqo-crypto/src/agent.rs +++ b/third_party/rust/neqo-crypto/src/agent.rs @@ -4,8 +4,8 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -pub use crate::agentio::{as_c_void, Record, RecordList}; use crate::agentio::{AgentIo, METHODS}; +pub use crate::agentio::{Record, RecordList}; use crate::assert_initialized; use crate::auth::AuthenticationStatus; pub use crate::cert::CertificateInfo; @@ -19,7 +19,7 @@ use crate::secrets::SecretHolder; use crate::ssl::{self, PRBool}; use crate::time::{PRTime, Time}; -use neqo_common::{matches, qdebug, qinfo, qtrace, qwarn}; +use neqo_common::{qdebug, qinfo, qwarn}; use std::cell::RefCell; use std::convert::{TryFrom, TryInto}; use std::ffi::CString; @@ -43,13 +43,11 @@ pub enum HandshakeState { impl HandshakeState { #[must_use] - pub fn is_connected(&self) -> bool { - matches!(self, Self::Complete(_)) - } - - #[must_use] - pub fn is_final(&self) -> bool { - matches!(self, Self::Complete(_) | Self::Failed(_)) + pub fn connected(&self) -> bool { + match self { + Self::Complete(_) => true, + _ => false, + } } } @@ -79,7 +77,7 @@ fn get_alpn(fd: *mut ssl::PRFileDesc, pre: bool) -> Res> { } _ => None, }; - qtrace!([format!("{:p}", fd)], "got ALPN {:?}", alpn); + qinfo!([format!("{:p}", fd)], "got ALPN {:?}", alpn); Ok(alpn) } @@ -125,7 +123,7 @@ impl SecretAgentPreInfo { } #[must_use] pub fn max_early_data(&self) -> usize { - usize::try_from(self.info.maxEarlyDataSize).unwrap() + self.info.maxEarlyDataSize as usize } #[must_use] pub fn alpn(&self) -> Option<&String> { @@ -227,24 +225,24 @@ pub struct SecretAgent { impl SecretAgent { fn new() -> Res { - let mut io = Box::pin(AgentIo::new()); - let fd = Self::create_fd(&mut io)?; - Ok(Self { - fd, + let mut agent = Self { + fd: null_mut(), secrets: SecretHolder::default(), raw: None, - io, + io: Pin::new(Box::new(AgentIo::new())), state: HandshakeState::New, - auth_required: Box::pin(false), - alert: Box::pin(None), - now: Box::pin(0), + auth_required: Pin::new(Box::new(false)), + alert: Pin::new(Box::new(None)), + now: Pin::new(Box::new(0)), extension_handlers: Vec::new(), inf: None, no_eoed: false, - }) + }; + agent.create_fd()?; + Ok(agent) } // Create a new SSL file descriptor. @@ -254,7 +252,7 @@ impl SecretAgent { // minimal, but it means that the two forms need casts to translate // between them. ssl::PRFileDesc is left as an opaque type, as the // ssl::SSL_* APIs only need an opaque type. - fn create_fd(io: &mut Pin>) -> Res<*mut ssl::PRFileDesc> { + fn create_fd(&mut self) -> Res<()> { assert_initialized(); let label = CString::new("sslwrapper")?; let id = unsafe { prio::PR_GetUniqueIdentity(label.as_ptr()) }; @@ -264,14 +262,15 @@ impl SecretAgent { return Err(Error::CreateSslSocket); } let fd = unsafe { - (*base_fd).secret = as_c_void(io) as *mut _; + (*base_fd).secret = &mut *self.io as *mut AgentIo as *mut _; ssl::SSL_ImportFD(null_mut(), base_fd as *mut ssl::PRFileDesc) }; if fd.is_null() { unsafe { prio::PR_Close(base_fd) }; return Err(Error::CreateSslSocket); } - Ok(fd) + self.fd = fd; + Ok(()) } unsafe extern "C" fn auth_complete_hook( @@ -321,7 +320,7 @@ impl SecretAgent { ssl::SSL_AuthCertificateHook( self.fd, Some(Self::auth_complete_hook), - as_c_void(&mut self.auth_required), + &mut *self.auth_required as *mut bool as *mut c_void, ) })?; @@ -329,21 +328,24 @@ impl SecretAgent { ssl::SSL_AlertSentCallback( self.fd, Some(Self::alert_sent_cb), - as_c_void(&mut self.alert), + &mut *self.alert as *mut Option as *mut c_void, ) })?; // TODO(mt) move to time.rs so we can remove PRTime definition from nss_ssl bindings. - unsafe { ssl::SSL_SetTimeFunc(self.fd, Some(Self::time_func), as_c_void(&mut self.now)) }?; + unsafe { + ssl::SSL_SetTimeFunc( + self.fd, + Some(Self::time_func), + &mut *self.now as *mut PRTime as *mut c_void, + ) + }?; self.configure()?; secstatus_to_res(unsafe { ssl::SSL_ResetHandshake(self.fd, is_server as ssl::PRBool) }) } /// Default configuration. - /// - /// # Errors - /// If `set_version_range` fails. fn configure(&mut self) -> Res<()> { self.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?; self.set_option(ssl::Opt::Locking, false)?; @@ -352,10 +354,6 @@ impl SecretAgent { Ok(()) } - /// Set the versions that are supported. - /// - /// # Errors - /// If the range of versions isn't supported. pub fn set_version_range(&mut self, min: Version, max: Version) -> Res<()> { let range = ssl::SSLVersionRange { min: min as ssl::PRUint16, @@ -364,10 +362,6 @@ impl SecretAgent { secstatus_to_res(unsafe { ssl::SSL_VersionRangeSet(self.fd, &range) }) } - /// Enable a set of ciphers. Note that the order of these is not respected. - /// - /// # Errors - /// If NSS can't enable or disable ciphers. pub fn enable_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> { let all_ciphers = unsafe { ssl::SSL_GetImplementedCiphers() }; let cipher_count = unsafe { ssl::SSL_GetNumImplementedCiphers() } as usize; @@ -386,10 +380,6 @@ impl SecretAgent { Ok(()) } - /// Set key exchange groups. - /// - /// # Errors - /// If the underlying API fails (which shouldn't happen). pub fn set_groups(&mut self, groups: &[Group]) -> Res<()> { // SSLNamedGroup is a different size to Group, so copy one by one. let group_vec: Vec<_> = groups @@ -404,9 +394,6 @@ impl SecretAgent { } /// Set TLS options. - /// - /// # Errors - /// Returns an error if the option or option value is invalid; i.e., never. pub fn set_option(&mut self, opt: ssl::Opt, value: bool) -> Res<()> { secstatus_to_res(unsafe { ssl::SSL_OptionSet(self.fd, opt.as_int(), opt.map_enabled(value)) @@ -414,9 +401,6 @@ impl SecretAgent { } /// Enable 0-RTT. - /// - /// # Errors - /// See `set_option`. pub fn enable_0rtt(&mut self) -> Res<()> { self.set_option(ssl::Opt::EarlyData, true) } @@ -432,9 +416,6 @@ impl SecretAgent { /// /// This asserts if no items are provided, or if any individual item is longer than /// 255 octets in length. - /// - /// # Errors - /// This should always panic rather than return an error. pub fn set_alpn(&mut self, protocols: &[impl AsRef]) -> Res<()> { // Validate and set length. let mut encoded_len = protocols.len(); @@ -478,9 +459,6 @@ impl SecretAgent { /// This can be called multiple times with different values for `ext`. The handler is provided as /// Rc> so that the caller is able to hold a reference to the handler and later access any /// state that it accumulates. - /// - /// # Errors - /// When the extension handler can't be successfully installed. pub fn extension_handler( &mut self, ext: Extension, @@ -521,9 +499,6 @@ impl SecretAgent { /// /// This includes whether 0-RTT was accepted and any information related to that. /// Calling this function collects all the relevant information. - /// - /// # Errors - /// When the underlying socket functions fail. pub fn preinfo(&self) -> Res { SecretAgentPreInfo::new(self.fd) } @@ -573,16 +548,13 @@ impl SecretAgent { Ok(()) } - /// Drive the TLS handshake, taking bytes from `input` and putting - /// any bytes necessary into `output`. - /// This takes the current time as `now`. - /// On success a tuple of a `HandshakeState` and usize indicate whether the handshake - /// is complete and how many bytes were written to `output`, respectively. - /// If the state is `HandshakeState::AuthenticationPending`, then ONLY call this - /// function if you want to proceed, because this will mark the certificate as OK. - /// - /// # Errors - /// When the handshake fails this returns an error. + // Drive the TLS handshake, taking bytes from @input and putting + // any bytes necessary into @output. + // This takes the current time as @now. + // On success a tuple of a HandshakeState and usize indicate whether the handshake + // is complete and how many bytes were written to @output, respectively. + // If the state is HandshakeState::AuthenticationPending, then ONLY call this + // function if you want to proceed, because this will mark the certificate as OK. pub fn handshake(&mut self, now: Instant, input: &[u8]) -> Res> { *self.now = Time::from(now).try_into()?; self.set_raw(false)?; @@ -632,15 +604,12 @@ impl SecretAgent { Ok(()) } - /// Drive the TLS handshake, but get the raw content of records, not - /// protected records as bytes. This function is incompatible with - /// `handshake()`; use either this or `handshake()` exclusively. - /// - /// Ideally, this only includes records from the current epoch. - /// If you send data from multiple epochs, you might end up being sad. - /// - /// # Errors - /// When the handshake fails this returns an error. + // Drive the TLS handshake, but get the raw content of records, not + // protected records as bytes. This function is incompatible with + // handshake(); use either this or handshake() exclusively. + // + // Ideally, this only includes records from the current epoch. + // If you send data from multiple epochs, you might end up being sad. pub fn handshake_raw(&mut self, now: Instant, input: Option) -> Res { *self.now = Time::from(now).try_into()?; let mut records = self.setup_raw()?; @@ -673,44 +642,18 @@ impl SecretAgent { Ok(*Pin::into_inner(records)) } - pub fn close(&mut self) { - // It should be safe to close multiple times. - if self.fd.is_null() { - return; - } - if let Some(true) = self.raw { - // Need to hold the record list in scope until the close is done. - let _records = self.setup_raw().expect("Can only close"); - unsafe { prio::PR_Close(self.fd as *mut prio::PRFileDesc) }; - } else { - // Need to hold the IO wrapper in scope until the close is done. - let _io = self.io.wrap(&[]); - unsafe { prio::PR_Close(self.fd as *mut prio::PRFileDesc) }; - }; - let _output = self.io.take_output(); - self.fd = null_mut(); - } - - /// State returns the status of the handshake. + // State returns the status of the handshake. #[must_use] pub fn state(&self) -> &HandshakeState { &self.state } - /// Take a read secret. This will only return a non-`None` value once. #[must_use] - pub fn read_secret(&mut self, epoch: Epoch) -> Option { - self.secrets.take_read(epoch) + pub fn read_secret(&self, epoch: Epoch) -> Option<&p11::SymKey> { + self.secrets.read().get(epoch) } - /// Take a write secret. #[must_use] - pub fn write_secret(&mut self, epoch: Epoch) -> Option { - self.secrets.take_write(epoch) - } -} - -impl Drop for SecretAgent { - fn drop(&mut self) { - self.close(); + pub fn write_secret(&self, epoch: Epoch) -> Option<&p11::SymKey> { + self.secrets.write().get(epoch) } } @@ -730,10 +673,6 @@ pub struct Client { } impl Client { - /// Create a new client agent. - /// - /// # Errors - /// Errors returned if the socket can't be created or configured. pub fn new(server_name: &str) -> Res { let mut agent = SecretAgent::new()?; let url = CString::new(server_name)?; @@ -741,7 +680,7 @@ impl Client { agent.ready(false)?; let mut client = Self { agent, - resumption: Box::pin(None), + resumption: Pin::new(Box::new(None)), }; client.ready()?; Ok(client) @@ -763,12 +702,11 @@ impl Client { } fn ready(&mut self) -> Res<()> { - let fd = self.fd; unsafe { ssl::SSL_SetResumptionTokenCallback( - fd, + self.fd, Some(Self::resumption_token_cb), - as_c_void(&mut self.resumption), + &mut *self.resumption as *mut Option> as *mut c_void, ) } } @@ -780,10 +718,6 @@ impl Client { } /// Enable resumption, using a token previously provided. - /// - /// # Errors - /// Error returned when the resumption token is invalid or - /// the socket is not able to use the value. pub fn set_resumption_token(&mut self, token: &[u8]) -> Res<()> { unsafe { ssl::SSL_SetResumptionToken( @@ -850,10 +784,6 @@ pub struct Server { } impl Server { - /// Create a new server agent. - /// - /// # Errors - /// Errors returned when NSS fails. pub fn new(certificates: &[impl AsRef]) -> Res { let mut agent = SecretAgent::new()?; @@ -923,22 +853,16 @@ impl Server { /// Enable 0-RTT. This shadows the function of the same name that can be accessed /// via the Deref implementation on Server. - /// - /// # Errors - /// Returns an error if the underlying NSS functions fail. pub fn enable_0rtt( &mut self, anti_replay: &AntiReplay, max_early_data: u32, checker: Box, ) -> Res<()> { - let mut check_state = Box::pin(ZeroRttCheckState::new(self.agent.fd, checker)); + let mut check_state = Pin::new(Box::new(ZeroRttCheckState::new(self.agent.fd, checker))); + let arg = &mut *check_state as *mut ZeroRttCheckState as *mut c_void; unsafe { - ssl::SSL_HelloRetryRequestCallback( - self.agent.fd, - Some(Self::hello_retry_cb), - as_c_void(&mut check_state), - ) + ssl::SSL_HelloRetryRequestCallback(self.agent.fd, Some(Self::hello_retry_cb), arg) }?; unsafe { ssl::SSL_SetMaxEarlyDataSize(self.agent.fd, max_early_data) }?; self.zero_rtt_check = Some(check_state); @@ -950,9 +874,6 @@ impl Server { /// Send a session ticket to the client. /// This adds |extra| application-specific content into that ticket. /// The records that are sent are captured and returned. - /// - /// # Errors - /// If NSS is unable to send a ticket, or if this agent is incorrectly configured. pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res { *self.agent.now = Time::from(now).try_into()?; let records = self.setup_raw()?; diff --git a/third_party/rust/neqo-crypto/src/agentio.rs b/third_party/rust/neqo-crypto/src/agentio.rs index 21f839d9f7f7..b0db7766e88d 100644 --- a/third_party/rust/neqo-crypto/src/agentio.rs +++ b/third_party/rust/neqo-crypto/src/agentio.rs @@ -26,11 +26,6 @@ type PrStatus = prio::PRStatus::Type; const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS; const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE; -/// Convert a pinned, boxed object into a void pointer. -pub fn as_c_void(pin: &mut Pin>) -> *mut c_void { - Pin::into_inner(pin.as_mut()) as *mut T as *mut c_void -} - // This holds the length of the slice, not the slice itself. #[derive(Default, Debug)] struct RecordLength { @@ -117,10 +112,9 @@ impl RecordList { /// Create a new record list. pub(crate) fn setup(fd: *mut ssl::PRFileDesc) -> Res>> { - let mut records = Box::pin(Self::default()); - unsafe { - ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records)) - }?; + let mut records = Pin::new(Box::new(Self::default())); + let records_ptr = &mut *records as *mut Self as *mut c_void; + unsafe { ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), records_ptr) }?; Ok(records) } } diff --git a/third_party/rust/neqo-crypto/src/constants.rs b/third_party/rust/neqo-crypto/src/constants.rs index b714c4bb6799..10ae877f619e 100644 --- a/third_party/rust/neqo-crypto/src/constants.rs +++ b/third_party/rust/neqo-crypto/src/constants.rs @@ -12,15 +12,7 @@ use crate::ssl; // for values outside of those that are defined here. pub type Alert = u8; - pub type Epoch = u16; -// TLS doesn't really have an "initial" concept that maps to QUIC so directly, -// but this should be clear enough. -pub const TLS_EPOCH_INITIAL: Epoch = 0 as Epoch; -pub const TLS_EPOCH_ZERO_RTT: Epoch = 1 as Epoch; -pub const TLS_EPOCH_HANDSHAKE: Epoch = 2 as Epoch; -// Also, we don't use TLS epochs > 3. -pub const TLS_EPOCH_APPLICATION_DATA: Epoch = 3 as Epoch; /// Rather than defining a type alias and a bunch of constants, which leads to a ton of repetition, /// use this macro. diff --git a/third_party/rust/neqo-crypto/src/ext.rs b/third_party/rust/neqo-crypto/src/ext.rs index 7ccbeec68c27..772eb46af701 100644 --- a/third_party/rust/neqo-crypto/src/ext.rs +++ b/third_party/rust/neqo-crypto/src/ext.rs @@ -4,7 +4,6 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use crate::agentio::as_c_void; use crate::constants::*; use crate::err::Res; use crate::ssl::{ @@ -118,14 +117,11 @@ impl ExtensionTracker { } /// Use the provided handler to manage an extension. This is quite unsafe. - /// /// # Safety + /// /// The holder of this `ExtensionTracker` needs to ensure that it lives at /// least as long as the file descriptor, as NSS provides no way to remove /// an extension handler once it is configured. - /// - /// # Errors - /// If the underlying NSS API fails to register a handler. pub unsafe fn new( fd: *mut PRFileDesc, extension: Extension, @@ -144,15 +140,16 @@ impl ExtensionTracker { // This way, only this "outer" code deals with the reference count. let mut tracker = Self { extension, - handler: Box::pin(Box::new(handler)), + handler: Pin::new(Box::new(Box::new(handler))), }; + let p = &mut *tracker.handler as *mut BoxedExtensionHandler as *mut c_void; SSL_InstallExtensionHooks( fd, extension, Some(Self::extension_writer), - as_c_void(&mut tracker.handler), + p, Some(Self::extension_handler), - as_c_void(&mut tracker.handler), + p, )?; Ok(tracker) } diff --git a/third_party/rust/neqo-crypto/src/hkdf.rs b/third_party/rust/neqo-crypto/src/hkdf.rs index b6aba04018d5..3496fa055e53 100644 --- a/third_party/rust/neqo-crypto/src/hkdf.rs +++ b/third_party/rust/neqo-crypto/src/hkdf.rs @@ -34,7 +34,7 @@ experimental_api!(SSL_HkdfExpandLabel( secret: *mut *mut PK11SymKey, )); -fn key_size(version: Version, cipher: Cipher) -> Res { +pub fn key_size(version: Version, cipher: Cipher) -> Res { if version != TLS_VERSION_1_3 { return Err(Error::UnsupportedVersion); } @@ -45,18 +45,11 @@ fn key_size(version: Version, cipher: Cipher) -> Res { }) } -/// Generate a random key of the right size for the given suite. -/// -/// # Errors -/// Only if NSS fails. -pub fn generate_key(version: Version, cipher: Cipher) -> Res { - import_key(version, cipher, &random(key_size(version, cipher)?)) +pub fn generate_key(version: Version, cipher: Cipher, size: usize) -> Res { + import_key(version, cipher, &random(size)?) } /// Import a symmetric key for use with HKDF. -/// -/// # Errors -/// Errors returned if the key buffer is an incompatible size or the NSS functions fail. pub fn import_key(version: Version, cipher: Cipher, buf: &[u8]) -> Res { if version != TLS_VERSION_1_3 { return Err(Error::UnsupportedVersion); @@ -93,9 +86,6 @@ pub fn import_key(version: Version, cipher: Cipher, buf: &[u8]) -> Res { } /// Extract a PRK from the given salt and IKM using the algorithm defined in RFC 5869. -/// -/// # Errors -/// Errors returned if inputs are too large or the NSS functions fail. pub fn extract( version: Version, cipher: Cipher, @@ -115,9 +105,6 @@ pub fn extract( } /// Expand a PRK using the HKDF-Expand-Label function defined in RFC 8446. -/// -/// # Errors -/// Errors returned if inputs are too large or the NSS functions fail. pub fn expand_label( version: Version, cipher: Cipher, diff --git a/third_party/rust/neqo-crypto/src/hp.rs b/third_party/rust/neqo-crypto/src/hp.rs index d901debfe159..29accbc8555d 100644 --- a/third_party/rust/neqo-crypto/src/hp.rs +++ b/third_party/rust/neqo-crypto/src/hp.rs @@ -29,20 +29,16 @@ experimental_api!(SSL_HkdfExpandLabelWithMech( secret: *mut *mut PK11SymKey, )); -#[derive(Clone)] pub struct HpKey(SymKey); impl Debug for HpKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "HP-{:?}", self.0) + f.write_str("HP Key") } } impl HpKey { /// QUIC-specific API for extracting a header-protection key. - /// - /// # Errors - /// Errors if HKDF fails or if the label is too long to fit in a `c_uint`. pub fn extract(version: Version, cipher: Cipher, prk: &SymKey, label: &str) -> Res { let l = label.as_bytes(); let mut secret: *mut PK11SymKey = null_mut(); @@ -77,10 +73,6 @@ impl HpKey { } /// Generate a header protection mask for QUIC. - /// - /// # Errors - /// An error is returned if the NSS functions fail; a sample of the - /// wrong size is the obvious cause. #[allow(clippy::cast_sign_loss)] pub fn mask(&self, sample: &[u8]) -> Res> { let k: *mut PK11SymKey = *self.0; diff --git a/third_party/rust/neqo-crypto/src/lib.rs b/third_party/rust/neqo-crypto/src/lib.rs index 794a5b25c65d..90f2d16adde8 100644 --- a/third_party/rust/neqo-crypto/src/lib.rs +++ b/third_party/rust/neqo-crypto/src/lib.rs @@ -41,7 +41,7 @@ pub use self::agent::{ pub use self::constants::*; pub use self::err::{Error, PRErrorCode, Res}; pub use self::ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult}; -pub use self::p11::{random, SymKey}; +pub use self::p11::SymKey; pub use self::replay::AntiReplay; pub use self::secrets::SecretDirection; pub use auth::AuthenticationStatus; @@ -104,18 +104,6 @@ pub fn init() { } } -/// This enables SSLTRACE by calling a simple, harmless function to trigger its -/// side effects. SSLTRACE is not enabled in NSS until a socket is made or -/// global options are accessed. Reading an option is the least impact approach. -/// This allows us to use SSLTRACE in all of our unit tests and programs. -#[cfg(debug_assertions)] -fn enable_ssl_trace() { - let opt = ssl::Opt::Locking.as_int(); - let mut _v: ::std::os::raw::c_int = 0; - secstatus_to_res(unsafe { ssl::SSL_OptionGetDefault(opt, &mut _v) }) - .expect("SSL_OptionGetDefault failed"); -} - pub fn init_db>(dir: P) { time::init(); unsafe { @@ -147,9 +135,6 @@ pub fn init_db>(dir: P) { )) .expect("SSL_ConfigServerSessionIDCache failed"); - #[cfg(debug_assertions)] - enable_ssl_trace(); - NssLoaded::Db(path.into_boxed_path()) }); } diff --git a/third_party/rust/neqo-crypto/src/p11.rs b/third_party/rust/neqo-crypto/src/p11.rs index e3d4081a7739..4e3b23b23a97 100644 --- a/third_party/rust/neqo-crypto/src/p11.rs +++ b/third_party/rust/neqo-crypto/src/p11.rs @@ -67,9 +67,6 @@ scoped_ptr!(Slot, PK11SlotInfo, PK11_FreeSlot); impl SymKey { /// You really don't want to use this. - /// - /// # Errors - /// Internal errors in case of failures in NSS. pub fn as_bytes<'a>(&'a self) -> Res<&'a [u8]> { secstatus_to_res(unsafe { PK11_ExtractKeyValue(self.ptr) })?; @@ -82,15 +79,6 @@ impl SymKey { } } -impl Clone for SymKey { - #[must_use] - fn clone(&self) -> Self { - let ptr = unsafe { PK11_ReferenceSymKey(self.ptr) }; - assert!(!ptr.is_null()); - Self { ptr } - } -} - impl std::fmt::Debug for SymKey { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { if let Ok(b) = self.as_bytes() { @@ -102,14 +90,10 @@ impl std::fmt::Debug for SymKey { } /// Generate a randomized buffer. -#[must_use] -pub fn random(size: usize) -> Vec { +pub fn random(size: usize) -> Res> { let mut buf = vec![0; size]; - secstatus_to_res(unsafe { - PK11_GenerateRandom(buf.as_mut_ptr(), buf.len().try_into().unwrap()) - }) - .unwrap(); - buf + secstatus_to_res(unsafe { PK11_GenerateRandom(buf.as_mut_ptr(), buf.len().try_into()?) })?; + Ok(buf) } #[cfg(test)] diff --git a/third_party/rust/neqo-crypto/src/prio.rs b/third_party/rust/neqo-crypto/src/prio.rs index 6c55bdd082f1..01bef67f7026 100644 --- a/third_party/rust/neqo-crypto/src/prio.rs +++ b/third_party/rust/neqo-crypto/src/prio.rs @@ -4,8 +4,11 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#![allow(dead_code, non_upper_case_globals, non_snake_case)] -#![allow(clippy::cognitive_complexity, clippy::empty_enum, clippy::too_many_lines)] +#![allow(dead_code)] +#![allow(non_upper_case_globals)] +#![allow(non_snake_case)] +#![allow(clippy::cognitive_complexity)] +#![allow(clippy::empty_enum)] include!(concat!(env!("OUT_DIR"), "/nspr_io.rs")); diff --git a/third_party/rust/neqo-crypto/src/replay.rs b/third_party/rust/neqo-crypto/src/replay.rs index 72d902ec2aee..cb59ccb14a72 100644 --- a/third_party/rust/neqo-crypto/src/replay.rs +++ b/third_party/rust/neqo-crypto/src/replay.rs @@ -49,10 +49,6 @@ pub struct AntiReplay { impl AntiReplay { /// Make a new anti-replay context. /// See the documentation in NSS for advice on how to set these values. - /// - /// # Errors - /// Returns an error if `now` is in the past relative to our baseline or - /// NSS is unable to generate an anti-replay context. pub fn new(now: Instant, window: Duration, k: usize, bits: usize) -> Res { let mut ctx: *mut SSLAntiReplayContext = null_mut(); unsafe { diff --git a/third_party/rust/neqo-crypto/src/result.rs b/third_party/rust/neqo-crypto/src/result.rs index e6148dd054a4..2d69eb7927b1 100644 --- a/third_party/rust/neqo-crypto/src/result.rs +++ b/third_party/rust/neqo-crypto/src/result.rs @@ -90,7 +90,7 @@ mod tests { #[test] fn is_err_zero_code() { - // This code doesn't work without initializing NSS first. + // This code doesn't work without initializing NSS first. fixture_init(); set_error_code(0); diff --git a/third_party/rust/neqo-crypto/src/secrets.rs b/third_party/rust/neqo-crypto/src/secrets.rs index efc10de0490b..368326d60f93 100644 --- a/third_party/rust/neqo-crypto/src/secrets.rs +++ b/third_party/rust/neqo-crypto/src/secrets.rs @@ -4,15 +4,14 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use crate::agentio::as_c_void; use crate::constants::*; use crate::err::Res; use crate::p11::{PK11SymKey, PK11_ReferenceSymKey, SymKey}; use crate::ssl::{PRFileDesc, SSLSecretCallback, SSLSecretDirection}; use neqo_common::qdebug; +use std::ops::Deref; use std::os::raw::c_void; -use std::pin::Pin; use std::ptr::NonNull; experimental_api!(SSL_SecretCallback( @@ -46,7 +45,7 @@ pub struct DirectionalSecrets { } impl DirectionalSecrets { - fn put(&mut self, epoch: Epoch, key: SymKey) { + pub fn put(&mut self, epoch: Epoch, key: SymKey) { assert!(epoch > 0); let i = (epoch - 1) as usize; assert!(i < self.secrets.len()); @@ -54,11 +53,11 @@ impl DirectionalSecrets { self.secrets[i] = Some(key); } - pub fn take(&mut self, epoch: Epoch) -> Option { + pub fn get(&self, epoch: Epoch) -> Option<&SymKey> { assert!(epoch > 0); let i = (epoch - 1) as usize; assert!(i < self.secrets.len()); - self.secrets[i].take() + self.secrets[i].as_ref() } } @@ -87,10 +86,10 @@ impl Secrets { None => panic!("NSS shouldn't be passing out NULL secrets"), Some(p) => SymKey::new(p), }; - self.put(SecretDirection::from(dir), epoch, key); + self.put(dir.into(), epoch, key); } - fn put(&mut self, dir: SecretDirection, epoch: Epoch, key: SymKey) { + pub fn put(&mut self, dir: SecretDirection, epoch: Epoch, key: SymKey) { qdebug!("{:?} secret available for {:?}", dir, epoch); let keys = match dir { SecretDirection::Read => &mut self.r, @@ -98,34 +97,33 @@ impl Secrets { }; keys.put(epoch, key); } + + pub fn read(&self) -> &DirectionalSecrets { + &self.r + } + + pub fn write(&self) -> &DirectionalSecrets { + &self.w + } } -#[derive(Debug)] +#[derive(Debug, Default)] pub struct SecretHolder { - secrets: Pin>, + secrets: Box, } impl SecretHolder { /// This registers with NSS. The lifetime of this object needs to match the lifetime /// of the connection, or bad things might happen. pub fn register(&mut self, fd: *mut PRFileDesc) -> Res<()> { - let p = as_c_void(&mut self.secrets); - unsafe { SSL_SecretCallback(fd, Some(Secrets::secret_available), p) } - } - - pub fn take_read(&mut self, epoch: Epoch) -> Option { - self.secrets.r.take(epoch) - } - - pub fn take_write(&mut self, epoch: Epoch) -> Option { - self.secrets.w.take(epoch) + let p = &*self.secrets as *const Secrets as *const c_void; + unsafe { SSL_SecretCallback(fd, Some(Secrets::secret_available), p as *mut c_void) } } } -impl Default for SecretHolder { - fn default() -> Self { - Self { - secrets: Box::pin(Secrets::default()), - } +impl Deref for SecretHolder { + type Target = Secrets; + fn deref(&self) -> &Self::Target { + self.secrets.as_ref() } } diff --git a/third_party/rust/neqo-crypto/src/selfencrypt.rs b/third_party/rust/neqo-crypto/src/selfencrypt.rs index 0e0b83c8defb..f8ff37b01328 100644 --- a/third_party/rust/neqo-crypto/src/selfencrypt.rs +++ b/third_party/rust/neqo-crypto/src/selfencrypt.rs @@ -27,10 +27,9 @@ impl SelfEncrypt { const VERSION: u8 = 1; const SALT_LENGTH: usize = 16; - /// # Errors - /// Failure to generate a new HKDF key using NSS results in an error. pub fn new(version: Version, cipher: Cipher) -> Res { - let key = hkdf::generate_key(version, cipher)?; + let sz = hkdf::key_size(version, cipher)?; + let key = hkdf::generate_key(version, cipher, sz)?; Ok(Self { version, cipher, @@ -48,11 +47,9 @@ impl SelfEncrypt { } /// Rotate keys. This causes any previous key that is being held to be replaced by the current key. - /// - /// # Errors - /// Failure to generate a new HKDF key using NSS results in an error. pub fn rotate(&mut self) -> Res<()> { - let new_key = hkdf::generate_key(self.version, self.cipher)?; + let sz = hkdf::key_size(self.version, self.cipher)?; + let new_key = hkdf::generate_key(self.version, self.cipher, sz)?; self.old_key = Some(mem::replace(&mut self.key, new_key)); let (kid, _) = self.key_id.overflowing_add(1); self.key_id = kid; @@ -64,9 +61,6 @@ impl SelfEncrypt { /// the encrypted `plaintext`, plus a version number and salt. /// `aad` is only used as input to the AEAD, it is not included in the output; the /// caller is responsible for carrying the AAD as appropriate. - /// - /// # Errors - /// Failure to protect using NSS AEAD APIs produces an error. #[allow(clippy::similar_names)] // aad is similar to aead pub fn seal(&self, aad: &[u8], plaintext: &[u8]) -> Res> { // Format is: @@ -77,7 +71,7 @@ impl SelfEncrypt { // opaque aead_encrypted(plaintext)[length as expanded]; // }; // AAD covers the entire header, plus the value of the AAD parameter that is provided. - let salt = random(Self::SALT_LENGTH); + let salt = random(Self::SALT_LENGTH)?; let aead = self.make_aead(&self.key, &salt)?; let encoded_len = 2 + salt.len() + plaintext.len() + aead.expansion(); @@ -117,10 +111,6 @@ impl SelfEncrypt { } /// Open the protected `ciphertext`. - /// - /// # Errors - /// Returns an error when the self-encrypted object is invalid; - /// when the keys have been rotated; or when NSS fails. #[allow(clippy::similar_names)] // aad is similar to aead pub fn open(&self, aad: &[u8], ciphertext: &[u8]) -> Res> { if ciphertext[0] != Self::VERSION { diff --git a/third_party/rust/neqo-crypto/src/ssl.rs b/third_party/rust/neqo-crypto/src/ssl.rs index ee5efe616e8d..b4b1544213ff 100644 --- a/third_party/rust/neqo-crypto/src/ssl.rs +++ b/third_party/rust/neqo-crypto/src/ssl.rs @@ -4,8 +4,10 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#![allow(dead_code, non_upper_case_globals, non_snake_case)] -#![allow(clippy::cognitive_complexity, clippy::too_many_lines)] +#![allow(dead_code)] +#![allow(non_upper_case_globals)] +#![allow(non_snake_case)] +#![allow(clippy::cognitive_complexity)] use crate::constants::*; diff --git a/third_party/rust/neqo-crypto/tests/aead.rs b/third_party/rust/neqo-crypto/tests/aead.rs index e350cdcde2c2..65fca96d040a 100644 --- a/third_party/rust/neqo-crypto/tests/aead.rs +++ b/third_party/rust/neqo-crypto/tests/aead.rs @@ -1,5 +1,4 @@ #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::pedantic)] use neqo_crypto::aead::Aead; use neqo_crypto::constants::*; diff --git a/third_party/rust/neqo-crypto/tests/agent.rs b/third_party/rust/neqo-crypto/tests/agent.rs index e3f7024922a4..a53f9c5584e4 100644 --- a/third_party/rust/neqo-crypto/tests/agent.rs +++ b/third_party/rust/neqo-crypto/tests/agent.rs @@ -1,5 +1,4 @@ #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::pedantic)] use neqo_crypto::*; @@ -49,7 +48,7 @@ fn basic() { // Calling handshake() again indicates that we're happy with the cert. let bytes = client.handshake(now(), &[]).expect("send CF"); assert!(!bytes.is_empty()); - assert!(client.state().is_connected()); + assert!(client.state().connected()); let client_info = client.info().expect("got info"); assert_eq!(TLS_VERSION_1_3, client_info.version()); @@ -57,14 +56,14 @@ fn basic() { let bytes = server.handshake(now(), &bytes[..]).expect("finish"); assert!(bytes.is_empty()); - assert!(server.state().is_connected()); + assert!(server.state().connected()); let server_info = server.info().expect("got info"); assert_eq!(TLS_VERSION_1_3, server_info.version()); assert_eq!(TLS_AES_128_GCM_SHA256, server_info.cipher_suite()); } -fn check_client_preinfo(client_preinfo: &SecretAgentPreInfo) { +fn check_client_preinfo(client_preinfo: SecretAgentPreInfo) { assert_eq!(client_preinfo.version(), None); assert_eq!(client_preinfo.cipher_suite(), None); assert_eq!(client_preinfo.early_data(), false); @@ -73,7 +72,7 @@ fn check_client_preinfo(client_preinfo: &SecretAgentPreInfo) { assert_eq!(client_preinfo.alpn(), None); } -fn check_server_preinfo(server_preinfo: &SecretAgentPreInfo) { +fn check_server_preinfo(server_preinfo: SecretAgentPreInfo) { assert_eq!(server_preinfo.version(), Some(TLS_VERSION_1_3)); assert_eq!(server_preinfo.cipher_suite(), Some(TLS_AES_128_GCM_SHA256)); assert_eq!(server_preinfo.early_data(), false); @@ -94,14 +93,14 @@ fn raw() { assert!(!client_records.is_empty()); assert_eq!(*client.state(), HandshakeState::InProgress); - check_client_preinfo(&client.preinfo().expect("get preinfo")); + check_client_preinfo(client.preinfo().expect("get preinfo")); let server_records = forward_records(now(), &mut server, client_records).expect("read CH, send SH"); assert!(!server_records.is_empty()); assert_eq!(*server.state(), HandshakeState::InProgress); - check_server_preinfo(&server.preinfo().expect("get preinfo")); + check_server_preinfo(server.preinfo().expect("get preinfo")); let client_records = forward_records(now(), &mut client, server_records).expect("send CF"); assert!(client_records.is_empty()); @@ -113,11 +112,11 @@ fn raw() { // Calling handshake() again indicates that we're happy with the cert. let client_records = client.handshake_raw(now(), None).expect("send CF"); assert!(!client_records.is_empty()); - assert!(client.state().is_connected()); + assert!(client.state().connected()); let server_records = forward_records(now(), &mut server, client_records).expect("finish"); assert!(server_records.is_empty()); - assert!(server.state().is_connected()); + assert!(server.state().connected()); // The client should have one certificate for the server. let mut certs = client.peer_certificate().unwrap(); @@ -354,21 +353,3 @@ fn reject_zero_rtt() { assert!(!client.info().unwrap().early_data_accepted()); assert!(!server.info().unwrap().early_data_accepted()); } - -#[test] -fn close() { - let mut client = Client::new("server.example").expect("should create client"); - let mut server = Server::new(&["key"]).expect("should create server"); - connect(&mut client, &mut server); - client.close(); - server.close(); -} - -#[test] -fn close_client_twice() { - let mut client = Client::new("server.example").expect("should create client"); - let mut server = Server::new(&["key"]).expect("should create server"); - connect(&mut client, &mut server); - client.close(); - client.close(); // Should be a noop. -} diff --git a/third_party/rust/neqo-crypto/tests/ext.rs b/third_party/rust/neqo-crypto/tests/ext.rs index d55ca8704bb5..e775a8e8210b 100644 --- a/third_party/rust/neqo-crypto/tests/ext.rs +++ b/third_party/rust/neqo-crypto/tests/ext.rs @@ -1,5 +1,4 @@ #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::pedantic)] use neqo_crypto::*; use std::cell::RefCell; @@ -60,10 +59,10 @@ impl ExtensionHandler for SimpleExtensionHandler { self.handled = true; if d.len() != 1 { ExtensionHandlerResult::Alert(50) // decode_error - } else if d[0] == 77 { - ExtensionHandlerResult::Ok - } else { + } else if d[0] != 77 { ExtensionHandlerResult::Alert(47) // illegal_parameter + } else { + ExtensionHandlerResult::Ok } } _ => ExtensionHandlerResult::Alert(110), // unsupported_extension diff --git a/third_party/rust/neqo-crypto/tests/handshake.rs b/third_party/rust/neqo-crypto/tests/handshake.rs index 0a44661d6601..2d12346ed501 100644 --- a/third_party/rust/neqo-crypto/tests/handshake.rs +++ b/third_party/rust/neqo-crypto/tests/handshake.rs @@ -16,7 +16,7 @@ pub fn forward_records( _ => HandshakeState::InProgress, }; let mut records_out = RecordList::default(); - for record in records_in { + for record in records_in.into_iter() { assert_eq!(records_out.len(), 0); assert_eq!(*agent.state(), expected_state); @@ -30,14 +30,18 @@ fn handshake(now: Instant, client: &mut SecretAgent, server: &mut SecretAgent) { let mut a = client; let mut b = server; let mut records = a.handshake_raw(now, None).unwrap(); - let is_done = |agent: &mut SecretAgent| agent.state().is_final(); + let is_done = |agent: &mut SecretAgent| match *agent.state() { + HandshakeState::Complete(_) | HandshakeState::Failed(_) => true, + _ => false, + }; while !is_done(b) { - records = if let Ok(r) = forward_records(now, &mut b, records) { - r - } else { - // TODO(mt) take the alert generated by the failed handshake - // and allow it to be sent to the peer. - return; + records = match forward_records(now, &mut b, records) { + Ok(r) => r, + _ => { + // TODO(mt) take the alert generated by the failed handshake + // and allow it to be sent to the peer. + return; + } }; if *b.state() == HandshakeState::AuthenticationPending { @@ -52,8 +56,8 @@ pub fn connect_at(now: Instant, client: &mut SecretAgent, server: &mut SecretAge handshake(now, client, server); qinfo!("client: {:?}", client.state()); qinfo!("server: {:?}", server.state()); - assert!(client.state().is_connected()); - assert!(server.state().is_connected()); + assert!(client.state().connected()); + assert!(server.state().connected()); } pub fn connect(client: &mut SecretAgent, server: &mut SecretAgent) { @@ -62,8 +66,8 @@ pub fn connect(client: &mut SecretAgent, server: &mut SecretAgent) { pub fn connect_fail(client: &mut SecretAgent, server: &mut SecretAgent) { handshake(now(), client, server); - assert!(!client.state().is_connected()); - assert!(!server.state().is_connected()); + assert!(!client.state().connected()); + assert!(!server.state().connected()); } #[derive(Clone, Copy, Debug)] @@ -81,7 +85,7 @@ pub struct PermissiveZeroRttChecker { impl Default for PermissiveZeroRttChecker { fn default() -> Self { - Self { resuming: true } + PermissiveZeroRttChecker { resuming: true } } } diff --git a/third_party/rust/neqo-crypto/tests/hkdf.rs b/third_party/rust/neqo-crypto/tests/hkdf.rs index 7d86bfeecc46..364c69498809 100644 --- a/third_party/rust/neqo-crypto/tests/hkdf.rs +++ b/third_party/rust/neqo-crypto/tests/hkdf.rs @@ -1,5 +1,4 @@ #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::pedantic)] use neqo_crypto::constants::*; use neqo_crypto::{hkdf, SymKey}; diff --git a/third_party/rust/neqo-crypto/tests/hp.rs b/third_party/rust/neqo-crypto/tests/hp.rs index 6922f6263a02..0183c9da6b1b 100644 --- a/third_party/rust/neqo-crypto/tests/hp.rs +++ b/third_party/rust/neqo-crypto/tests/hp.rs @@ -1,5 +1,4 @@ #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::pedantic)] use neqo_crypto::constants::*; use neqo_crypto::hkdf; @@ -14,35 +13,37 @@ fn make_hp(cipher: Cipher) -> HpKey { #[test] fn aes128() { - const EXPECTED: &[u8] = &[ - 0x04, 0x7b, 0xda, 0x65, 0xc3, 0x41, 0xcf, 0xbc, 0x5d, 0xe1, 0x75, 0x2b, 0x9d, 0x7d, 0xc3, - 0x14, - ]; - fixture_init(); let mask = make_hp(TLS_AES_128_GCM_SHA256) .mask(&[0; 16]) .expect("should produce a mask"); + const EXPECTED: &[u8] = &[ + 0x04, 0x7b, 0xda, 0x65, 0xc3, 0x41, 0xcf, 0xbc, 0x5d, 0xe1, 0x75, 0x2b, 0x9d, 0x7d, 0xc3, + 0x14, + ]; assert_eq!(mask, EXPECTED); } #[test] fn aes256() { - const EXPECTED: &[u8] = &[ - 0xb5, 0xea, 0xa2, 0x1c, 0x25, 0x77, 0x48, 0x18, 0xbf, 0x25, 0xea, 0xfa, 0xbd, 0x8d, 0x80, - 0x2b, - ]; - fixture_init(); let mask = make_hp(TLS_AES_256_GCM_SHA384) .mask(&[0; 16]) .expect("should produce a mask"); + const EXPECTED: &[u8] = &[ + 0xb5, 0xea, 0xa2, 0x1c, 0x25, 0x77, 0x48, 0x18, 0xbf, 0x25, 0xea, 0xfa, 0xbd, 0x8d, 0x80, + 0x2b, + ]; assert_eq!(mask, EXPECTED); } #[cfg(feature = "chacha")] #[test] fn chacha20_ctr() { + fixture_init(); + let mask = make_hp(TLS_CHACHA20_POLY1305_SHA256) + .mask(&[0; 16]) + .expect("should produce a mask"); const EXPECTED: &[u8] = &[ 0x34, 0x11, 0xb3, 0x53, 0x02, 0x0b, 0x16, 0xda, 0x0a, 0x85, 0x5a, 0x52, 0x0d, 0x06, 0x07, 0x1f, 0x4a, 0xb1, 0xaf, 0xf7, 0x83, 0xa8, 0xf0, 0x29, 0xc3, 0x19, 0xef, 0x57, 0x48, 0xe7, @@ -50,10 +51,5 @@ fn chacha20_ctr() { 0xf1, 0x62, 0x2f, 0x1e, 0xad, 0xdd, 0x23, 0x59, 0x45, 0xac, 0xd2, 0x19, 0x8a, 0xb4, 0x1f, 0x2f, 0x52, 0x46, 0x89, ]; - - fixture_init(); - let mask = make_hp(TLS_CHACHA20_POLY1305_SHA256) - .mask(&[0; 16]) - .expect("should produce a mask"); assert_eq!(mask, EXPECTED); } diff --git a/third_party/rust/neqo-crypto/tests/init.rs b/third_party/rust/neqo-crypto/tests/init.rs index 1e00f033d8ae..7b6ac80457d9 100644 --- a/third_party/rust/neqo-crypto/tests/init.rs +++ b/third_party/rust/neqo-crypto/tests/init.rs @@ -1,5 +1,4 @@ #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::pedantic)] // This uses external interfaces to neqo_crypto rather than being a module // inside of lib.rs. Because all other code uses the test_fixture module, @@ -12,9 +11,8 @@ use neqo_crypto::*; // Pull in the NSS internals so that we can ask NSS if it thinks that // it is properly initialized. -#[allow(dead_code, non_upper_case_globals)] -#[allow(clippy::redundant_static_lifetimes, clippy::unseparated_literal_suffix)] mod nss { + #![allow(clippy::redundant_static_lifetimes, dead_code, non_upper_case_globals)] include!(concat!(env!("OUT_DIR"), "/nss_init.rs")); } diff --git a/third_party/rust/neqo-crypto/tests/selfencrypt.rs b/third_party/rust/neqo-crypto/tests/selfencrypt.rs index 78dd323e1106..17460d622457 100644 --- a/third_party/rust/neqo-crypto/tests/selfencrypt.rs +++ b/third_party/rust/neqo-crypto/tests/selfencrypt.rs @@ -1,5 +1,4 @@ #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::pedantic)] use neqo_crypto::constants::*; use neqo_crypto::{init, selfencrypt::SelfEncrypt, Error}; diff --git a/third_party/rust/neqo-http3/.cargo-checksum.json b/third_party/rust/neqo-http3/.cargo-checksum.json index dca4cbe07101..68eff8ce650d 100644 --- a/third_party/rust/neqo-http3/.cargo-checksum.json +++ b/third_party/rust/neqo-http3/.cargo-checksum.json @@ -1 +1 @@ -{"files":{"Cargo.toml":"418e6d9c20e2ae00d840b2dc57a94bec093194e4974a3b25863a6656be5d4257","src/client_events.rs":"8e77e6e92c3d5933621f2baee3baacab230486ad8b6df1eca321ea74ed7cdcbd","src/connection.rs":"8499ea115fc061eb5d2eedb0a5cac6069a255ad756e6e89ce2f6e6a8dc5772fc","src/connection_client.rs":"2c8e7ffc5b67defef0e2a43b0053d1044b2ce4e83cadf1d3ee4d1e313cec35ff","src/connection_server.rs":"a32edf220f4664dccfc80141128a211d9997d86e8e988726c1380836015f1d0e","src/control_stream_local.rs":"319f8277fc4765b31a4a094bfd663e681d71831532925b0043bae5da96202e64","src/control_stream_remote.rs":"c205633af8539cd55f289071c6845b5bb2b0a9778f15976829c5d4a492360e19","src/hframe.rs":"8733c3af83da5ddbc6aa238710662fdba7f790bf266d242b40d727f68315d1cc","src/hsettings_frame.rs":"349a4413ce13f03e05264e6c4b22d231276a1c96e3983aada4478b038ec89dbc","src/lib.rs":"f47849cc5f47945c95aa58a9ed830ff512572512f0b3a7ddb6b545e9e16e08bf","src/server.rs":"212b98c363c0160304eaf02bb73dad6138236f52390ab7664ce4984657fdcca3","src/server_connection_events.rs":"d2b973a095f29cb0ac6fb84705165b034960d09b2dde7693bab96e6b802c6fba","src/server_events.rs":"f997bd329d45115f6a527eba8f0f1ecf21c0dd9a3184f08fc5002e34f4cfe2f0","src/stream_type_reader.rs":"da2b7b0358cb4829493cb964cae67c85e9efdf4127958aade7a56733ddc4f12e","src/transaction_client.rs":"65f0cea42843ad9057f587d6ef0a1751f46fe13db468904434ebaeb27f763a84","src/transaction_server.rs":"8603c3f835f680e2c63c1ed1b5962b208acd476a12e4bb7221d68b36c57c505a","tests/httpconn.rs":"7955f6ac4406b5d770e0fb10258aff529a1c01020374dfc5f85d8608abb68f6f"},"package":null} \ No newline at end of file +{"files":{"Cargo.toml":"a6bb6d8f914fc1199d18189cd7d83a1f2e54b9bca7ef40408fe7297827f81ea6","src/client_events.rs":"8e77e6e92c3d5933621f2baee3baacab230486ad8b6df1eca321ea74ed7cdcbd","src/connection.rs":"b5a34c2e519b9cfd931fa6305c9e85a13c5a10981a8b98a6c81b56efe8da7221","src/connection_client.rs":"c056e209d9ec306695152c31d43beeaf6c407511db89180a5c828c454339ce51","src/connection_server.rs":"eb5a11935a1d8ce04f29907c3146aca1dae1a2a0d44025d4a82d21bdddce72b7","src/control_stream_local.rs":"319f8277fc4765b31a4a094bfd663e681d71831532925b0043bae5da96202e64","src/control_stream_remote.rs":"1b96316d6eecc582436382517fcffdb2bb4f9d73194301bc3e2e253d3322e95e","src/hframe.rs":"5f6508d473b6cb5a6539be7db5e5f9f5ee8ee72bf166a351781a3990b67afda5","src/hsettings_frame.rs":"61bc427ece16ed7aa9d7040b0fb8f67d596c40c1c47f3711ed317bc973bfbc6e","src/lib.rs":"2c5b25a8aa7f1476b7ca5223d926a45afe1f8ca7d956221bdd924849d7c8e564","src/server.rs":"c4004faf8968c43fcb83387ea305b7c1a06ba5236f6f5b3cfdfd6e2f5542da04","src/server_connection_events.rs":"d2b973a095f29cb0ac6fb84705165b034960d09b2dde7693bab96e6b802c6fba","src/server_events.rs":"0ff69865aadede05aa6d9917b0ccbe2e44558a829a708d7cd58014887a099ad8","src/stream_type_reader.rs":"be1ea1f553292b5447f0d6d89bdfa73732189db236ce34b4067cda0276229540","src/transaction_client.rs":"9f11285df4f480b5ecbf1d5ab60d6aff5b62ee608f079f06ea3bd43e8f144d45","src/transaction_server.rs":"f4daf0389b437e9081bc30db8e25c56e8b9e2e2423465b7986cc3629f5cdc7b3","tests/httpconn.rs":"7955f6ac4406b5d770e0fb10258aff529a1c01020374dfc5f85d8608abb68f6f"},"package":null} \ No newline at end of file diff --git a/third_party/rust/neqo-http3/Cargo.toml b/third_party/rust/neqo-http3/Cargo.toml index 09c2a6a5714c..87b7042f8806 100644 --- a/third_party/rust/neqo-http3/Cargo.toml +++ b/third_party/rust/neqo-http3/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "neqo-http3" -version = "0.1.13" +version = "0.1.12" authors = ["Dragana Damjanovic "] edition = "2018" license = "MIT/Apache-2.0" diff --git a/third_party/rust/neqo-http3/src/connection.rs b/third_party/rust/neqo-http3/src/connection.rs index 0dc1e7dec727..8a257df56488 100644 --- a/third_party/rust/neqo-http3/src/connection.rs +++ b/third_party/rust/neqo-http3/src/connection.rs @@ -85,7 +85,7 @@ impl Http3Connection { if max_table_size > (1 << 30) - 1 { panic!("Wrong max_table_size"); } - Self { + Http3Connection { state: Http3State::Initializing, local_settings: LocalSettings { max_table_size, @@ -341,7 +341,6 @@ impl Http3Connection { } pub fn handle_state_change(&mut self, conn: &mut Connection, state: &State) -> Res { - qdebug!([self], "Handle state change {:?}", state); match state { State::Connected => { debug_assert!(matches!( @@ -364,7 +363,7 @@ impl Http3Connection { } State::Closed(error) => { if !matches!(self.state, Http3State::Closed(_)) { - self.state = Http3State::Closed(error.clone().into()); + self.state = Http3State::Closing(error.clone().into()); Ok(true) } else { Ok(false) diff --git a/third_party/rust/neqo-http3/src/connection_client.rs b/third_party/rust/neqo-http3/src/connection_client.rs index 113bd7b5fc36..e7448d731593 100644 --- a/third_party/rust/neqo-http3/src/connection_client.rs +++ b/third_party/rust/neqo-http3/src/connection_client.rs @@ -44,7 +44,7 @@ impl Http3Client { max_table_size: u32, max_blocked_streams: u16, ) -> Res { - Ok(Self::new_with_conn( + Ok(Http3Client::new_with_conn( Connection::new_client(server_name, protocols, cid_manager, local_addr, remote_addr)?, max_table_size, max_blocked_streams, @@ -52,7 +52,7 @@ impl Http3Client { } pub fn new_with_conn(c: Connection, max_table_size: u32, max_blocked_streams: u16) -> Self { - Self { + Http3Client { conn: c, base_handler: Http3Connection::new(max_table_size, max_blocked_streams), events: Http3ClientEvents::default(), @@ -184,7 +184,6 @@ impl Http3Client { match transaction.read_response_headers() { Ok((headers, fin)) => { if transaction.done() { - qinfo!([self], "read_response_headers transaction done"); self.base_handler.transactions.remove(&stream_id); } Ok((headers, fin)) @@ -568,7 +567,7 @@ mod tests { assert_eq!(client.state(), Http3State::Connected); let _ = server.conn.process(out.dgram(), now()); - assert!(server.conn.state().connected()); + assert_eq!(*server.conn.state(), State::Connected); } // Perform only Quic transport handshake. @@ -2483,7 +2482,7 @@ mod tests { assert_eq!(client.state(), Http3State::Connected); let _out = server.conn.process(out.dgram(), now()); - assert!(server.conn.state().connected()); + assert_eq!(*server.conn.state(), State::Connected); assert!(client.tls_info().unwrap().resumed()); assert!(server.conn.tls_info().unwrap().resumed()); @@ -2516,7 +2515,7 @@ mod tests { let out = client.process(out.dgram(), now()); assert_eq!(client.state(), Http3State::Connected); let out = server.conn.process(out.dgram(), now()); - assert!(server.conn.state().connected()); + assert_eq!(*server.conn.state(), State::Connected); let out = client.process(out.dgram(), now()); assert!(out.as_dgram_ref().is_none()); @@ -2608,6 +2607,7 @@ mod tests { original_settings: &[HSetting], resumption_settings: &[HSetting], expected_client_state: Http3State, + expected_server_state: State, expected_encoder_stream_data: &[u8], ) { let mut client = default_http3_client(); @@ -2643,7 +2643,7 @@ mod tests { assert_eq!(client.state(), Http3State::Connected); let _out = server.conn.process(out.dgram(), now()); - assert!(server.conn.state().connected()); + assert_eq!(*server.conn.state(), State::Connected); assert!(client.tls_info().unwrap().resumed()); assert!(server.conn.tls_info().unwrap().resumed()); @@ -2661,7 +2661,7 @@ mod tests { client.process(out.dgram(), now()); assert_eq!(client.state(), expected_client_state); - assert!(server.conn.state().connected()); + assert_eq!(*server.conn.state(), expected_server_state); } #[test] @@ -2679,6 +2679,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Connected, + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2697,6 +2698,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Closing(CloseError::Application(265)), + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2715,6 +2717,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Closing(CloseError::Application(265)), + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2733,6 +2736,7 @@ mod tests { HSetting::new(HSettingType::BlockedStreams, 100), ], Http3State::Connected, + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2752,6 +2756,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Closing(CloseError::Application(265)), + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2771,6 +2776,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Closing(CloseError::Application(265)), + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2790,6 +2796,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Connected, + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2809,6 +2816,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Closing(CloseError::Application(265)), + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2828,6 +2836,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 20000), ], Http3State::Connected, + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2847,6 +2856,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 5000), ], Http3State::Closing(CloseError::Application(265)), + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2866,6 +2876,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Connected, + State::Connected, ENCODER_STREAM_DATA, ); } @@ -2885,6 +2896,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Connected, + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } @@ -2904,6 +2916,7 @@ mod tests { HSetting::new(HSettingType::MaxHeaderListSize, 10000), ], Http3State::Closing(CloseError::Application(265)), + State::Connected, ENCODER_STREAM_DATA_WITH_CAP_INSTRUCTION, ); } diff --git a/third_party/rust/neqo-http3/src/connection_server.rs b/third_party/rust/neqo-http3/src/connection_server.rs index 93daed65dd77..55dbaa4e3553 100644 --- a/third_party/rust/neqo-http3/src/connection_server.rs +++ b/third_party/rust/neqo-http3/src/connection_server.rs @@ -27,7 +27,7 @@ impl ::std::fmt::Display for Http3ServerHandler { impl Http3ServerHandler { pub fn new(max_table_size: u32, max_blocked_streams: u16) -> Self { - Self { + Http3ServerHandler { base_handler: Http3Connection::new(max_table_size, max_blocked_streams), events: Http3ServerConnEvents::default(), } diff --git a/third_party/rust/neqo-http3/src/control_stream_remote.rs b/third_party/rust/neqo-http3/src/control_stream_remote.rs index 6ea2813dfc13..b54404d8485e 100644 --- a/third_party/rust/neqo-http3/src/control_stream_remote.rs +++ b/third_party/rust/neqo-http3/src/control_stream_remote.rs @@ -24,8 +24,8 @@ impl ::std::fmt::Display for ControlStreamRemote { } impl ControlStreamRemote { - pub fn new() -> Self { - Self { + pub fn new() -> ControlStreamRemote { + ControlStreamRemote { stream_id: None, frame_reader: HFrameReader::new(), fin: false, diff --git a/third_party/rust/neqo-http3/src/hframe.rs b/third_party/rust/neqo-http3/src/hframe.rs index 5aaebb17b56b..7f31b19b2504 100644 --- a/third_party/rust/neqo-http3/src/hframe.rs +++ b/third_party/rust/neqo-http3/src/hframe.rs @@ -65,14 +65,14 @@ pub enum HFrame { impl HFrame { fn get_type(&self) -> HFrameType { match self { - Self::Data { .. } => H3_FRAME_TYPE_DATA, - Self::Headers { .. } => H3_FRAME_TYPE_HEADERS, - Self::CancelPush { .. } => H3_FRAME_TYPE_CANCEL_PUSH, - Self::Settings { .. } => H3_FRAME_TYPE_SETTINGS, - Self::PushPromise { .. } => H3_FRAME_TYPE_PUSH_PROMISE, - Self::Goaway { .. } => H3_FRAME_TYPE_GOAWAY, - Self::MaxPushId { .. } => H3_FRAME_TYPE_MAX_PUSH_ID, - Self::DuplicatePush { .. } => H3_FRAME_TYPE_DUPLICATE_PUSH, + HFrame::Data { .. } => H3_FRAME_TYPE_DATA, + HFrame::Headers { .. } => H3_FRAME_TYPE_HEADERS, + HFrame::CancelPush { .. } => H3_FRAME_TYPE_CANCEL_PUSH, + HFrame::Settings { .. } => H3_FRAME_TYPE_SETTINGS, + HFrame::PushPromise { .. } => H3_FRAME_TYPE_PUSH_PROMISE, + HFrame::Goaway { .. } => H3_FRAME_TYPE_GOAWAY, + HFrame::MaxPushId { .. } => H3_FRAME_TYPE_MAX_PUSH_ID, + HFrame::DuplicatePush { .. } => H3_FRAME_TYPE_DUPLICATE_PUSH, } } @@ -80,19 +80,19 @@ impl HFrame { enc.encode_varint(self.get_type()); match self { - Self::Data { len } | Self::Headers { len } => { + HFrame::Data { len } | HFrame::Headers { len } => { // DATA and HEADERS frames only encode the length here. enc.encode_varint(*len); } - Self::CancelPush { push_id } => { + HFrame::CancelPush { push_id } => { enc.encode_vvec_with(|enc_inner| { enc_inner.encode_varint(*push_id); }); } - Self::Settings { settings } => { + HFrame::Settings { settings } => { settings.encode_frame_contents(enc); } - Self::PushPromise { + HFrame::PushPromise { push_id, header_block, } => { @@ -100,17 +100,17 @@ impl HFrame { enc.encode_varint(*push_id); enc.encode(header_block); } - Self::Goaway { stream_id } => { + HFrame::Goaway { stream_id } => { enc.encode_vvec_with(|enc_inner| { enc_inner.encode_varint(*stream_id); }); } - Self::MaxPushId { push_id } => { + HFrame::MaxPushId { push_id } => { enc.encode_vvec_with(|enc_inner| { enc_inner.encode_varint(*push_id); }); } - Self::DuplicatePush { push_id } => { + HFrame::DuplicatePush { push_id } => { enc.encode_vvec_with(|enc_inner| { enc_inner.encode_varint(*push_id); }); @@ -120,14 +120,14 @@ impl HFrame { pub fn is_allowed(&self, s: HStreamType) -> bool { match self { - Self::Data { .. } => !(s == HStreamType::Control), - Self::Headers { .. } => !(s == HStreamType::Control), - Self::CancelPush { .. } => (s == HStreamType::Control), - Self::Settings { .. } => (s == HStreamType::Control), - Self::PushPromise { .. } => (s == HStreamType::Request), - Self::Goaway { .. } => (s == HStreamType::Control), - Self::MaxPushId { .. } => (s == HStreamType::Control), - Self::DuplicatePush { .. } => (s == HStreamType::Request), + HFrame::Data { .. } => !(s == HStreamType::Control), + HFrame::Headers { .. } => !(s == HStreamType::Control), + HFrame::CancelPush { .. } => (s == HStreamType::Control), + HFrame::Settings { .. } => (s == HStreamType::Control), + HFrame::PushPromise { .. } => (s == HStreamType::Request), + HFrame::Goaway { .. } => (s == HStreamType::Control), + HFrame::MaxPushId { .. } => (s == HStreamType::Control), + HFrame::DuplicatePush { .. } => (s == HStreamType::Request), } } } @@ -158,8 +158,8 @@ impl Default for HFrameReader { } impl HFrameReader { - pub fn new() -> Self { - Self { + pub fn new() -> HFrameReader { + HFrameReader { state: HFrameReaderState::BeforeFrame, hframe_type: 0, hframe_len: 0, diff --git a/third_party/rust/neqo-http3/src/hsettings_frame.rs b/third_party/rust/neqo-http3/src/hsettings_frame.rs index 66aaa8b11e5e..9d3dd379cca8 100644 --- a/third_party/rust/neqo-http3/src/hsettings_frame.rs +++ b/third_party/rust/neqo-http3/src/hsettings_frame.rs @@ -37,7 +37,7 @@ pub struct HSetting { impl HSetting { pub fn new(setting_type: HSettingType, value: u64) -> Self { - Self { + HSetting { setting_type, value, } @@ -51,7 +51,7 @@ pub struct HSettings { impl HSettings { pub fn new(settings: &[HSetting]) -> Self { - Self { + HSettings { settings: settings.to_vec(), } } diff --git a/third_party/rust/neqo-http3/src/lib.rs b/third_party/rust/neqo-http3/src/lib.rs index b280aa3a799d..eb4e94290071 100644 --- a/third_party/rust/neqo-http3/src/lib.rs +++ b/third_party/rust/neqo-http3/src/lib.rs @@ -5,7 +5,6 @@ // except according to those terms. #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::use_self)] mod client_events; mod connection; @@ -23,9 +22,9 @@ mod transaction_client; pub mod transaction_server; //pub mod server; -use neqo_qpack::Error as QpackError; +use neqo_qpack; +use neqo_transport; pub use neqo_transport::Output; -use neqo_transport::{AppError, Error as TransportError}; pub use client_events::Http3ClientEvent; pub use connection::Http3State; @@ -64,84 +63,82 @@ pub enum Error { InvalidStreamId, NoMoreData, NotEnoughData, - TransportError(TransportError), + TransportError(neqo_transport::Error), Unavailable, Unexpected, InvalidResumptionToken, } impl Error { - pub fn code(&self) -> AppError { + pub fn code(&self) -> neqo_transport::AppError { match self { - Self::HttpNoError => 0x100, - Self::HttpGeneralProtocolError => 0x101, - Self::HttpInternalError => 0x102, - Self::HttpStreamCreationError => 0x103, - Self::HttpClosedCriticalStream => 0x104, - Self::HttpFrameUnexpected => 0x105, - Self::HttpFrameError => 0x106, - Self::HttpExcessiveLoad => 0x107, - Self::HttpIdError => 0x108, - Self::HttpSettingsError => 0x109, - Self::HttpMissingSettings => 0x10a, - Self::HttpRequestRejected => 0x10b, - Self::HttpRequestCancelled => 0x10c, - Self::HttpRequestIncomplete => 0x10d, - Self::HttpEarlyResponse => 0x10e, - Self::HttpConnectError => 0x10f, - Self::HttpVersionFallback => 0x110, - Self::QpackError(e) => e.code(), + Error::HttpNoError => 0x100, + Error::HttpGeneralProtocolError => 0x101, + Error::HttpInternalError => 0x102, + Error::HttpStreamCreationError => 0x103, + Error::HttpClosedCriticalStream => 0x104, + Error::HttpFrameUnexpected => 0x105, + Error::HttpFrameError => 0x106, + Error::HttpExcessiveLoad => 0x107, + Error::HttpIdError => 0x108, + Error::HttpSettingsError => 0x109, + Error::HttpMissingSettings => 0x10a, + Error::HttpRequestRejected => 0x10b, + Error::HttpRequestCancelled => 0x10c, + Error::HttpRequestIncomplete => 0x10d, + Error::HttpEarlyResponse => 0x10e, + Error::HttpConnectError => 0x10f, + Error::HttpVersionFallback => 0x110, + Error::QpackError(e) => e.code(), // These are all internal errors. _ => 3, } } -} -impl From for Error { - fn from(err: TransportError) -> Self { - Self::TransportError(err) - } -} - -impl From for Error { - fn from(err: QpackError) -> Self { - Self::QpackError(err) - } -} - -impl From for Error { - fn from(error: AppError) -> Self { + pub fn from_code(error: neqo_transport::AppError) -> Error { match error { - 0x100 => Self::HttpNoError, - 0x101 => Self::HttpGeneralProtocolError, - 0x102 => Self::HttpInternalError, - 0x103 => Self::HttpStreamCreationError, - 0x104 => Self::HttpClosedCriticalStream, - 0x105 => Self::HttpFrameUnexpected, - 0x106 => Self::HttpFrameError, - 0x107 => Self::HttpExcessiveLoad, - 0x108 => Self::HttpIdError, - 0x109 => Self::HttpSettingsError, - 0x10a => Self::HttpMissingSettings, - 0x10b => Self::HttpRequestRejected, - 0x10c => Self::HttpRequestCancelled, - 0x10d => Self::HttpRequestIncomplete, - 0x10e => Self::HttpEarlyResponse, - 0x10f => Self::HttpConnectError, - 0x110 => Self::HttpVersionFallback, - 0x200 => Self::QpackError(QpackError::DecompressionFailed), - 0x201 => Self::QpackError(QpackError::EncoderStreamError), - 0x202 => Self::QpackError(QpackError::DecoderStreamError), - _ => Self::HttpInternalError, + 0x100 => Error::HttpNoError, + 0x101 => Error::HttpGeneralProtocolError, + 0x102 => Error::HttpInternalError, + 0x103 => Error::HttpStreamCreationError, + 0x104 => Error::HttpClosedCriticalStream, + 0x105 => Error::HttpFrameUnexpected, + 0x106 => Error::HttpFrameError, + 0x107 => Error::HttpExcessiveLoad, + 0x108 => Error::HttpIdError, + 0x109 => Error::HttpSettingsError, + 0x10a => Error::HttpMissingSettings, + 0x10b => Error::HttpRequestRejected, + 0x10c => Error::HttpRequestCancelled, + 0x10d => Error::HttpRequestIncomplete, + 0x10e => Error::HttpEarlyResponse, + 0x10f => Error::HttpConnectError, + 0x110 => Error::HttpVersionFallback, + 0x200 => Error::QpackError(neqo_qpack::Error::DecompressionFailed), + 0x201 => Error::QpackError(neqo_qpack::Error::EncoderStreamError), + 0x202 => Error::QpackError(neqo_qpack::Error::DecoderStreamError), + _ => Error::HttpInternalError, } } } +impl From for Error { + fn from(err: neqo_transport::Error) -> Self { + Error::TransportError(err) + } +} + +impl From for Error { + fn from(err: neqo_qpack::Error) -> Self { + Error::QpackError(err) + } +} + impl ::std::error::Error for Error { fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> { match self { - Self::TransportError(e) => Some(e), - Self::QpackError(e) => Some(e), + Error::TransportError(e) => Some(e), + Error::QpackError(e) => Some(e), _ => None, } } diff --git a/third_party/rust/neqo-http3/src/server.rs b/third_party/rust/neqo-http3/src/server.rs index 8cbba0ba1b2e..4c1e0a10406c 100644 --- a/third_party/rust/neqo-http3/src/server.rs +++ b/third_party/rust/neqo-http3/src/server.rs @@ -43,8 +43,8 @@ impl Http3Server { cid_manager: Rc>, max_table_size: u32, max_blocked_streams: u16, - ) -> Res { - Ok(Self { + ) -> Res { + Ok(Http3Server { server: Server::new(now, certs, protocols, anti_replay, cid_manager)?, max_table_size, max_blocked_streams, diff --git a/third_party/rust/neqo-http3/src/server_events.rs b/third_party/rust/neqo-http3/src/server_events.rs index 0742c789f931..809150719b53 100644 --- a/third_party/rust/neqo-http3/src/server_events.rs +++ b/third_party/rust/neqo-http3/src/server_events.rs @@ -39,7 +39,7 @@ impl ClientRequestStream { handler: Rc>, stream_id: u64, ) -> Self { - Self { + ClientRequestStream { conn, handler, stream_id, diff --git a/third_party/rust/neqo-http3/src/stream_type_reader.rs b/third_party/rust/neqo-http3/src/stream_type_reader.rs index 850c3266c039..bee282c4c075 100644 --- a/third_party/rust/neqo-http3/src/stream_type_reader.rs +++ b/third_party/rust/neqo-http3/src/stream_type_reader.rs @@ -14,8 +14,8 @@ pub struct NewStreamTypeReader { } impl NewStreamTypeReader { - pub fn new() -> Self { - Self { + pub fn new() -> NewStreamTypeReader { + NewStreamTypeReader { reader: IncrementalDecoder::decode_varint(), fin: false, } diff --git a/third_party/rust/neqo-http3/src/transaction_client.rs b/third_party/rust/neqo-http3/src/transaction_client.rs index e6c266a7e74d..2052032aadb0 100644 --- a/third_party/rust/neqo-http3/src/transaction_client.rs +++ b/third_party/rust/neqo-http3/src/transaction_client.rs @@ -9,7 +9,7 @@ use crate::hframe::{HFrame, HFrameReader}; use crate::client_events::Http3ClientEvents; use crate::connection::Http3Transaction; use crate::Header; -use neqo_common::{qdebug, qinfo, qtrace, qwarn, Encoder}; +use neqo_common::{qdebug, qinfo, qtrace, Encoder}; use neqo_qpack::decoder::QPackDecoder; use neqo_qpack::encoder::QPackEncoder; use neqo_transport::Connection; @@ -36,8 +36,8 @@ struct Request { } impl Request { - pub fn new(method: &str, scheme: &str, host: &str, path: &str, headers: &[Header]) -> Self { - let mut r = Self { + pub fn new(method: &str, scheme: &str, host: &str, path: &str, headers: &[Header]) -> Request { + let mut r = Request { method: method.to_owned(), scheme: scheme.to_owned(), host: host.to_owned(), @@ -180,9 +180,9 @@ impl TransactionClient { path: &str, headers: &[Header], conn_events: Http3ClientEvents, - ) -> Self { + ) -> TransactionClient { qinfo!("Create a request stream_id={}", stream_id); - Self { + TransactionClient { send_state: TransactionSendState::SendingHeaders { request: Request::new(method, scheme, host, path, headers), fin: false, @@ -291,7 +291,6 @@ impl TransactionClient { HFrame::PushPromise { .. } => Err(Error::HttpIdError), HFrame::Headers { .. } => { // TODO implement trailers! - qwarn!([self], "Received trailers"); Err(Error::HttpFrameUnexpected) } _ => Err(Error::HttpFrameUnexpected), diff --git a/third_party/rust/neqo-http3/src/transaction_server.rs b/third_party/rust/neqo-http3/src/transaction_server.rs index 8c8b54096bd5..5ddda4668a25 100644 --- a/third_party/rust/neqo-http3/src/transaction_server.rs +++ b/third_party/rust/neqo-http3/src/transaction_server.rs @@ -42,9 +42,9 @@ pub struct TransactionServer { } impl TransactionServer { - pub fn new(stream_id: u64, conn_events: Http3ServerConnEvents) -> Self { + pub fn new(stream_id: u64, conn_events: Http3ServerConnEvents) -> TransactionServer { qinfo!("Create a request stream_id={}", stream_id); - Self { + TransactionServer { recv_state: TransactionRecvState::WaitingForHeaders, send_state: TransactionSendState::Initial, stream_id, diff --git a/third_party/rust/neqo-qpack/.cargo-checksum.json b/third_party/rust/neqo-qpack/.cargo-checksum.json index bb641564c634..1fbe1a2706bd 100644 --- a/third_party/rust/neqo-qpack/.cargo-checksum.json +++ b/third_party/rust/neqo-qpack/.cargo-checksum.json @@ -1 +1 @@ -{"files":{"Cargo.toml":"b7caad70be20f74848df1286119f9b9a7895eb07de9f6be0394fef5020fdd993","src/decoder.rs":"aac6d5b3dfb19779351c2568a4c54c551e2de83d0e458246c818a6af15514477","src/encoder.rs":"992bb211273d48b9d85ab4bc6bad5c0dbc5c12e7f9e7c1bb35f1b0db5eb7cffe","src/huffman.rs":"720eedace45205098a0b2210c876906ce15b7be469a799e75e70baafac8adee8","src/huffman_decode_helper.rs":"e4734353591770dfe9a9047b0be5d9068150433e9cea8cad029444b42b0afa39","src/huffman_table.rs":"06fea766a6276ac56c7ee0326faed800a742c15fda1f33bf2513e6cc6a5e6d27","src/lib.rs":"fa5b76f6b7db74904fe0317bbc1214292494365328c2efa06b4146cbd2ee6c1b","src/qpack_helper.rs":"200ab8bcb60728e3bcacf25b7006fa54b544458bfee5e66e09fa472a614347fc","src/qpack_send_buf.rs":"471e3b0af9f8783aa1bfe11a1959bf5694e62bc2d8e1cf783c933af81e3f3cf9","src/static_table.rs":"fda9d5c6f38f94b0bf92d3afdf8432dce6e27e189736596e16727090c77b78ec","src/table.rs":"1043a6e0761d9ff05a35dfab3b5a0e871d1b1666e83bc4fbd9e97383ca44e59e"},"package":null} \ No newline at end of file +{"files":{"Cargo.toml":"7c469ea56bf87154c0eebf67eeb253ff0172453250bb165d74c08e71272cfea5","src/decoder.rs":"a8e20a9f82846e873197c75d1b5ab49270014c807e90b1331ebd2a449d2d84e0","src/encoder.rs":"78da509611b5869d320795c42bef944b6499c0f207c73818c1908f1a1cf001fc","src/huffman.rs":"720eedace45205098a0b2210c876906ce15b7be469a799e75e70baafac8adee8","src/huffman_decode_helper.rs":"e4734353591770dfe9a9047b0be5d9068150433e9cea8cad029444b42b0afa39","src/huffman_table.rs":"06fea766a6276ac56c7ee0326faed800a742c15fda1f33bf2513e6cc6a5e6d27","src/lib.rs":"9895f91624d58388cf4906a80b3d8e9109abf24c9df542af8acb34b3a6e2231e","src/qpack_helper.rs":"200ab8bcb60728e3bcacf25b7006fa54b544458bfee5e66e09fa472a614347fc","src/qpack_send_buf.rs":"471e3b0af9f8783aa1bfe11a1959bf5694e62bc2d8e1cf783c933af81e3f3cf9","src/static_table.rs":"fda9d5c6f38f94b0bf92d3afdf8432dce6e27e189736596e16727090c77b78ec","src/table.rs":"f4f09692bf6ec863b0f066c88837d99f59a1fc4a8ca61bee4ed76d45a77c3cc4"},"package":null} \ No newline at end of file diff --git a/third_party/rust/neqo-qpack/Cargo.toml b/third_party/rust/neqo-qpack/Cargo.toml index b4cfa353dc6d..13c675594061 100644 --- a/third_party/rust/neqo-qpack/Cargo.toml +++ b/third_party/rust/neqo-qpack/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "neqo-qpack" -version = "0.1.13" +version = "0.1.12" authors = ["Dragana Damjanovic "] edition = "2018" license = "MIT/Apache-2.0" diff --git a/third_party/rust/neqo-qpack/src/decoder.rs b/third_party/rust/neqo-qpack/src/decoder.rs index 4bd46f43638a..d982f64e7a26 100644 --- a/third_party/rust/neqo-qpack/src/decoder.rs +++ b/third_party/rust/neqo-qpack/src/decoder.rs @@ -4,6 +4,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +#![allow(unused_variables, dead_code)] use crate::huffman::Huffman; use crate::qpack_helper::{ read_prefixed_encoded_int_slice, read_prefixed_encoded_int_with_connection, BufWrapper, @@ -73,6 +74,7 @@ enum QPackDecoderState { pub struct QPackDecoder { state: QPackDecoderState, table: HeaderTable, + increment: u64, total_num_of_inserts: u64, max_entries: u64, send_buf: QPData, @@ -84,11 +86,12 @@ pub struct QPackDecoder { } impl QPackDecoder { - pub fn new(max_table_size: u32, max_blocked_streams: u16) -> Self { + pub fn new(max_table_size: u32, max_blocked_streams: u16) -> QPackDecoder { qdebug!("Decoder: creating a new qpack decoder."); - Self { + QPackDecoder { state: QPackDecoderState::ReadInstruction, table: HeaderTable::new(false), + increment: 0, total_num_of_inserts: 0, max_entries: (f64::from(max_table_size) / 32.0).floor() as u64, send_buf: QPData::default(), @@ -209,6 +212,7 @@ impl QPackDecoder { qdebug!([label], "received instruction - duplicate index={}", v); self.table.duplicate(v)?; self.total_num_of_inserts += 1; + self.increment += 1; self.state = QPackDecoderState::ReadInstruction; } else { self.state = QPackDecoderState::Duplicate { index: v, cnt }; @@ -310,9 +314,10 @@ impl QPackDecoder { self.table.insert_with_name_ref( *name_static_table, *name_index, - &value_to_insert, + value_to_insert, )?; self.total_num_of_inserts += 1; + self.increment += 1; self.state = QPackDecoderState::ReadInstruction; } else { // waiting for more data @@ -426,8 +431,9 @@ impl QPackDecoder { mem::swap(&mut value_to_insert, value); } qdebug!([label], "received instruction - insert with name literal name={:x?} value={:x?}", name_to_insert, value_to_insert); - self.table.insert(&name_to_insert, &value_to_insert)?; + self.table.insert(name_to_insert, value_to_insert)?; self.total_num_of_inserts += 1; + self.increment += 1; self.state = QPackDecoderState::ReadInstruction; } else { // waiting for more data @@ -447,6 +453,7 @@ impl QPackDecoder { qdebug!([label], "received instruction - duplicate index={}", index); self.table.duplicate(*index)?; self.total_num_of_inserts += 1; + self.increment += 1; self.state = QPackDecoderState::ReadInstruction; } else { // waiting for more data @@ -478,18 +485,13 @@ impl QPackDecoder { if cap > u64::from(self.max_table_size) { return Err(Error::EncoderStreamError); } - self.table - .set_capacity(cap) - .map_err(|_| Error::EncoderStreamError) + self.table.set_capacity(cap); + Ok(()) } - fn header_ack(&mut self, stream_id: u64, required_inserts: u64) { - let ack_increment_delta = required_inserts - self.table.get_acked_inserts_cnt(); + fn header_ack(&mut self, stream_id: u64) { self.send_buf .encode_prefixed_encoded_int(0x80, 1, stream_id); - self.table - .increment_acked(ack_increment_delta) - .expect("This should never happen"); } pub fn cancel_stream(&mut self, stream_id: u64) { @@ -499,13 +501,10 @@ impl QPackDecoder { pub fn send(&mut self, conn: &mut Connection) -> Res<()> { // Encode increment instruction if needed. - let ack_increment_delta = self.total_num_of_inserts - self.table.get_acked_inserts_cnt(); - if ack_increment_delta > 0 { + if self.increment > 0 { self.send_buf - .encode_prefixed_encoded_int(0x00, 2, ack_increment_delta); - self.table - .increment_acked(ack_increment_delta) - .expect("This should never happen"); + .encode_prefixed_encoded_int(0x00, 2, self.increment); + self.increment = 0; } if self.send_buf.len() == 0 { Ok(()) @@ -554,7 +553,7 @@ impl QPackDecoder { if reader.done() { // Send header_ack if req_inserts != 0 { - self.header_ack(stream_id, req_inserts); + self.header_ack(stream_id); } qdebug!([self], "done decoding header block."); break Ok(Some(h)); @@ -790,100 +789,65 @@ fn read_prefixed_encoded_int_with_connection_wrap( #[cfg(test)] mod tests { use super::*; + use neqo_transport::ConnectionEvent; use neqo_transport::StreamType; use std::convert::TryInto; use test_fixture::*; - struct TestDecoder { - decoder: QPackDecoder, - send_stream_id: u64, - recv_stream_id: u64, - conn: Connection, - peer_conn: Connection, - } - - fn connect() -> TestDecoder { - let (mut conn, mut peer_conn) = test_fixture::connect(); + fn connect() -> (QPackDecoder, Connection, Connection, u64, u64) { + let (mut conn_c, mut conn_s) = test_fixture::connect(); // create a stream - let recv_stream_id = peer_conn.stream_create(StreamType::UniDi).unwrap(); - let send_stream_id = conn.stream_create(StreamType::UniDi).unwrap(); + let recv_stream_id = conn_s.stream_create(StreamType::UniDi).unwrap(); + let send_stream_id = conn_c.stream_create(StreamType::UniDi).unwrap(); // create a decoder let mut decoder = QPackDecoder::new(300, 100); decoder.add_send_stream(send_stream_id); - TestDecoder { - decoder, - send_stream_id, - recv_stream_id, - conn, - peer_conn, - } - } - - fn recv_instruction(decoder: &mut TestDecoder, encoder_instruction: &[u8], res: Res<()>) { - let _ = decoder - .peer_conn - .stream_send(decoder.recv_stream_id, encoder_instruction); - let out = decoder.peer_conn.process(None, now()); - decoder.conn.process(out.dgram(), now()); - assert_eq!( - decoder - .decoder - .read_instructions(&mut decoder.conn, decoder.recv_stream_id), - res - ); - } - - fn send_instructions_and_check(decoder: &mut TestDecoder, decoder_instruction: &[u8]) { - decoder.decoder.send(&mut decoder.conn).unwrap(); - let out = decoder.conn.process(None, now()); - decoder.peer_conn.process(out.dgram(), now()); - let mut buf = [0u8; 100]; - let (amount, fin) = decoder - .peer_conn - .stream_recv(decoder.send_stream_id, &mut buf) - .unwrap(); - assert_eq!(fin, false); - assert_eq!(&buf[..amount], decoder_instruction); - } - - fn decode_headers( - decoder: &mut TestDecoder, - header_block: &[u8], - headers: &[Header], - stream_id: u64, - ) { - let decoded_headers = decoder - .decoder - .decode_header_block(header_block, stream_id) - .unwrap(); - let h = decoded_headers.unwrap(); - assert_eq!(h, headers); + (decoder, conn_c, conn_s, recv_stream_id, send_stream_id) } fn test_instruction( capacity: u64, instruction: &[u8], - res: Res<()>, + err: Option, decoder_instruction: &[u8], check_capacity: u64, ) { - let mut decoder = connect(); + let (mut decoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(); if capacity > 0 { - assert!(decoder.decoder.set_capacity(capacity).is_ok()); + assert!(decoder.set_capacity(capacity).is_ok()); + } + // send an instruction + let _ = conn_s.stream_send(recv_stream_id, instruction); + let out = conn_s.process(None, now()); + conn_c.process(out.dgram(), now()); + + let res = decoder.read_instructions(&mut conn_c, recv_stream_id); + assert_eq!(err.is_some(), res.is_err()); + if let Some(expected_err) = err { + assert_eq!(expected_err, res.unwrap_err()); } - // recv an instruction - recv_instruction(&mut decoder, instruction, res); - - // send decoder instruction and check that is what we expect. - send_instructions_and_check(&mut decoder, decoder_instruction); + decoder.send(&mut conn_c).unwrap(); + let out = conn_c.process(None, now()); + conn_s.process(out.dgram(), now()); + let mut found_instruction = false; + while let Some(e) = conn_s.next_event() { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + let mut buf = [0u8; 100]; + let (amount, fin) = conn_s.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(fin, false); + assert_eq!(buf[..amount], decoder_instruction[..]); + found_instruction = true; + } + } + assert_eq!(found_instruction, !decoder_instruction.is_empty()); if check_capacity > 0 { - assert_eq!(decoder.decoder.capacity(), check_capacity); + assert_eq!(decoder.capacity(), check_capacity); } } @@ -893,7 +857,7 @@ mod tests { test_instruction( 0, &[0xc4, 0x04, 0x31, 0x32, 0x33, 0x34], - Err(Error::DecoderStreamError), + Some(Error::DecoderStreamError), &[0x03], 0, ); @@ -905,7 +869,7 @@ mod tests { test_instruction( 100, &[0xc4, 0x04, 0x31, 0x32, 0x33, 0x34], - Ok(()), + None, &[0x03, 0x01], 0, ); @@ -920,7 +884,7 @@ mod tests { 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x04, 0x31, 0x32, 0x33, 0x34, ], - Ok(()), + None, &[0x03, 0x01], 0, ); @@ -928,7 +892,7 @@ mod tests { #[test] fn test_recv_change_capacity() { - test_instruction(0, &[0x3f, 0xa9, 0x01], Ok(()), &[0x03], 200); + test_instruction(0, &[0x3f, 0xa9, 0x01], None, &[0x03], 200); } #[test] @@ -936,7 +900,7 @@ mod tests { test_instruction( 0, &[0x3f, 0xf1, 0x02], - Err(Error::EncoderStreamError), + Some(Error::EncoderStreamError), &[0x03], 0, ); @@ -945,24 +909,49 @@ mod tests { // this test tests header decoding, the header acks command and the insert count increment command. #[test] fn test_duplicate() { - let mut decoder = connect(); + let (mut decoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(); - assert!(decoder.decoder.set_capacity(100).is_ok()); + assert!(decoder.set_capacity(100).is_ok()); - // receive an instruction - recv_instruction( - &mut decoder, + // send an instruction + let _ = conn_s.stream_send( + recv_stream_id, &[ 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x04, 0x31, 0x32, 0x33, 0x34, ], - Ok(()), ); + let out = conn_s.process(None, now()); + conn_c.process(out.dgram(), now()); + assert!(decoder + .read_instructions(&mut conn_c, recv_stream_id) + .is_ok()); - // receive the second instruction, a duplicate instruction. - recv_instruction(&mut decoder, &[0x00], Ok(())); + // send the second instruction, a duplicate instruction. + let _ = conn_s.stream_send(recv_stream_id, &[0x00]); + let out = conn_s.process(None, now()); + conn_c.process(out.dgram(), now()); + if decoder + .read_instructions(&mut conn_c, recv_stream_id) + .is_err() + { + panic!("failed to read") + } - send_instructions_and_check(&mut decoder, &[0x03, 0x02]); + decoder.send(&mut conn_c).unwrap(); + let out = conn_c.process(None, now()); + conn_s.process(out.dgram(), now()); + let mut found_instruction = false; + while let Some(e) = conn_s.next_event() { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + let mut buf = [0u8; 100]; + let (amount, fin) = conn_s.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(fin, false); + assert_eq!(buf[..amount], [0x03, 0x02]); + found_instruction = true; + } + } + assert!(found_instruction); } struct TestElement { @@ -971,129 +960,6 @@ mod tests { pub encoder_inst: &'static [u8], } - #[test] - fn test_encode_incr_encode_header_ack_some() { - // 1. Decoder receives an instruction (header and value both as literal) - // 2. Decoder process the instruction and sends an increment instruction. - // 3. Decoder receives another two instruction (header and value both as literal) and - // a header block. - // 4. Now it sends only a header ack and an increment instruction with increment==1. - let headers = vec![ - (String::from("my-headera"), String::from("my-valuea")), - (String::from("my-headerb"), String::from("my-valueb")), - ]; - let header_block = &[0x03, 0x81, 0x10, 0x11]; - let first_encoder_inst = &[ - 0x4a, 0x6d, 0x79, 0x2d, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x61, 0x09, 0x6d, 0x79, - 0x2d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x61, - ]; - let second_encoder_inst = &[ - 0x4a, 0x6d, 0x79, 0x2d, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x62, 0x09, 0x6d, 0x79, - 0x2d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x62, 0x4a, 0x6d, 0x79, 0x2d, 0x68, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x63, 0x09, 0x6d, 0x79, 0x2d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x63, - ]; - - let mut decoder = connect(); - - assert!(decoder.decoder.set_capacity(200).is_ok()); - - recv_instruction(&mut decoder, first_encoder_inst, Ok(())); - - send_instructions_and_check(&mut decoder, &[0x03, 0x1]); - - recv_instruction(&mut decoder, second_encoder_inst, Ok(())); - - decode_headers(&mut decoder, header_block, &headers, 0); - - send_instructions_and_check(&mut decoder, &[0x80, 0x1]); - } - - #[test] - fn test_encode_incr_encode_header_ack_all() { - // 1. Decoder receives an instruction (header and value both as literal) - // 2. Decoder process the instruction and sends an increment instruction. - // 3. Decoder receives another instruction (header and value both as literal) and - // a header block. - // 4. Now it sends only a header ack. - let headers = vec![ - (String::from("my-headera"), String::from("my-valuea")), - (String::from("my-headerb"), String::from("my-valueb")), - ]; - let header_block = &[0x03, 0x81, 0x10, 0x11]; - let first_encoder_inst = &[ - 0x4a, 0x6d, 0x79, 0x2d, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x61, 0x09, 0x6d, 0x79, - 0x2d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x61, - ]; - let second_encoder_inst = &[ - 0x4a, 0x6d, 0x79, 0x2d, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x62, 0x09, 0x6d, 0x79, - 0x2d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x62, - ]; - - let mut decoder = connect(); - - assert!(decoder.decoder.set_capacity(200).is_ok()); - - recv_instruction(&mut decoder, first_encoder_inst, Ok(())); - - send_instructions_and_check(&mut decoder, &[0x03, 0x1]); - - recv_instruction(&mut decoder, second_encoder_inst, Ok(())); - - decode_headers(&mut decoder, header_block, &headers, 0); - - send_instructions_and_check(&mut decoder, &[0x80]); - } - - #[test] - fn test_header_ack_all() { - // Send two instructions to insert values into the dynamic table and then send a header - // that references them both. The result should be only a header acknowledgement. - let headers = vec![ - (String::from("my-headera"), String::from("my-valuea")), - (String::from("my-headerb"), String::from("my-valueb")), - ]; - let header_block = &[0x03, 0x81, 0x10, 0x11]; - let encoder_inst = &[ - 0x4a, 0x6d, 0x79, 0x2d, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x61, 0x09, 0x6d, 0x79, - 0x2d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x61, 0x4a, 0x6d, 0x79, 0x2d, 0x68, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x62, 0x09, 0x6d, 0x79, 0x2d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x62, - ]; - - let mut decoder = connect(); - - assert!(decoder.decoder.set_capacity(200).is_ok()); - - recv_instruction(&mut decoder, encoder_inst, Ok(())); - - decode_headers(&mut decoder, header_block, &headers, 0); - - send_instructions_and_check(&mut decoder, &[0x03, 0x80]); - } - - #[test] - fn test_header_ack_and_incr_instruction() { - // Send two instructions to insert values into the dynamic table and then send a header - // that references only the first. The result should be a header acknowledgement and a - // increment instruction. - let headers = vec![(String::from("my-headera"), String::from("my-valuea"))]; - let header_block = &[0x02, 0x80, 0x10]; - let encoder_inst = &[ - 0x4a, 0x6d, 0x79, 0x2d, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x61, 0x09, 0x6d, 0x79, - 0x2d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x61, 0x4a, 0x6d, 0x79, 0x2d, 0x68, 0x65, 0x61, - 0x64, 0x65, 0x72, 0x62, 0x09, 0x6d, 0x79, 0x2d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x62, - ]; - - let mut decoder = connect(); - - assert!(decoder.decoder.set_capacity(200).is_ok()); - - recv_instruction(&mut decoder, encoder_inst, Ok(())); - - decode_headers(&mut decoder, header_block, &headers, 0); - - send_instructions_and_check(&mut decoder, &[0x03, 0x80, 0x01]); - } - #[test] fn test_header_block_decoder() { let test_cases: [TestElement; 6] = [ @@ -1152,26 +1018,42 @@ mod tests { }, ]; - let mut decoder = connect(); - - assert!(decoder.decoder.set_capacity(200).is_ok()); + let (mut decoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(); + assert!(decoder.set_capacity(200).is_ok()); for (i, t) in test_cases.iter().enumerate() { - // receive an instruction + // send an instruction if !t.encoder_inst.is_empty() { - recv_instruction(&mut decoder, t.encoder_inst, Ok(())); + let _ = conn_s.stream_send(recv_stream_id, t.encoder_inst); + let out = conn_s.process(None, now()); + conn_c.process(out.dgram(), now()); + assert!(decoder + .read_instructions(&mut conn_c, recv_stream_id) + .is_ok()); } - decode_headers( - &mut decoder, - t.header_block, - &t.headers, - i.try_into().unwrap(), - ); + let headers = decoder + .decode_header_block(t.header_block, i.try_into().unwrap()) + .unwrap(); + let h = headers.unwrap(); + assert_eq!(h, t.headers); } // test header acks and the insert count increment command - send_instructions_and_check(&mut decoder, &[0x03, 0x82, 0x83, 0x84]); + decoder.send(&mut conn_c).unwrap(); + let out = conn_c.process(None, now()); + conn_s.process(out.dgram(), now()); + let mut found_instruction = false; + while let Some(e) = conn_s.next_event() { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + let mut buf = [0u8; 100]; + let (amount, fin) = conn_s.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(fin, false); + assert_eq!(buf[..amount], [0x03, 0x82, 0x83, 0x84, 0x1]); + found_instruction = true; + } + } + assert!(found_instruction); } #[test] @@ -1230,25 +1112,42 @@ mod tests { }, ]; - let mut decoder = connect(); + let (mut decoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(); - assert!(decoder.decoder.set_capacity(200).is_ok()); + assert!(decoder.set_capacity(200).is_ok()); for (i, t) in test_cases.iter().enumerate() { - // receive an instruction. + // send an instruction. if !t.encoder_inst.is_empty() { - recv_instruction(&mut decoder, t.encoder_inst, Ok(())); + let _ = conn_s.stream_send(recv_stream_id, t.encoder_inst); + let out = conn_s.process(None, now()); + conn_c.process(out.dgram(), now()); + // read the instruction. + assert!(decoder + .read_instructions(&mut conn_c, recv_stream_id) + .is_ok()); } - decode_headers( - &mut decoder, - t.header_block, - &t.headers, - i.try_into().unwrap(), - ); + let headers = decoder + .decode_header_block(t.header_block, i.try_into().unwrap()) + .unwrap(); + assert_eq!(headers.unwrap(), t.headers); } // test header acks and the insert count increment command - send_instructions_and_check(&mut decoder, &[0x03, 0x82, 0x83, 0x84]); + decoder.send(&mut conn_c).unwrap(); + let out = conn_c.process(None, now()); + conn_s.process(out.dgram(), now()); + let mut found_instruction = false; + while let Some(e) = conn_s.next_event() { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + let mut buf = [0u8; 100]; + let (amount, fin) = conn_s.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(fin, false); + assert_eq!(buf[..amount], [0x03, 0x82, 0x83, 0x84, 0x1]); + found_instruction = true; + } + } + assert!(found_instruction); } } diff --git a/third_party/rust/neqo-qpack/src/encoder.rs b/third_party/rust/neqo-qpack/src/encoder.rs index 5a0d1e18d309..e49d477d5306 100644 --- a/third_party/rust/neqo-qpack/src/encoder.rs +++ b/third_party/rust/neqo-qpack/src/encoder.rs @@ -4,16 +4,16 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. +#![allow(unused_variables, dead_code)] + use crate::huffman::encode_huffman; use crate::qpack_helper::read_prefixed_encoded_int_with_connection; use crate::qpack_send_buf::QPData; -use crate::table::{HeaderTable, LookupResult}; +use crate::table::HeaderTable; use crate::Header; use crate::{Error, Res}; use neqo_common::{qdebug, qtrace}; use neqo_transport::Connection; -use std::collections::{HashMap, HashSet, VecDeque}; -use std::convert::TryInto; pub const QPACK_UNI_STREAM_TYPE_ENCODER: u64 = 0x2; @@ -40,23 +40,19 @@ pub struct QPackEncoder { send_buf: QPData, max_entries: u64, instruction_reader_current_inst: Option, - instruction_reader_value: u64, // this is instruction dependent value. + instruction_reader_value: u64, // this is instrunction dependent value. instruction_reader_cnt: u8, // this is helper variable for reading a prefixed integer encoded value local_stream_id: Option, remote_stream_id: Option, max_blocked_streams: u16, - // Remember header blocks that are referring to dynamic table. - // There can be multiple header blocks in one stream, headers, trailer, push stream request, etc. - // This HashMap maps a stream ID to a list of header blocks. Each header block is a list of - // referenced dynamic table entries. - unacked_header_blocks: HashMap>>, - blocked_stream_cnt: u16, + blocked_streams: Vec, // remember request insert counds for blocked streams. + // TODO we may also remember stream_id and use stream acks as indication that a stream has beed unblocked. use_huffman: bool, } impl QPackEncoder { - pub fn new(use_huffman: bool) -> Self { - Self { + pub fn new(use_huffman: bool) -> QPackEncoder { + QPackEncoder { table: HeaderTable::new(true), send_buf: QPData::default(), max_entries: 0, @@ -66,27 +62,29 @@ impl QPackEncoder { local_stream_id: None, remote_stream_id: None, max_blocked_streams: 0, - unacked_header_blocks: HashMap::new(), - blocked_stream_cnt: 0, + blocked_streams: Vec::new(), use_huffman, } } pub fn set_max_capacity(&mut self, cap: u64) -> Res<()> { if cap > (1 << 30) - 1 { - // TODO dragana check what is the correct error. + // TODO dragana check wat is the correct error. return Err(Error::EncoderStreamError); } qdebug!([self], "Set max capacity to {}.", cap); self.max_entries = (cap as f64 / 32.0).floor() as u64; // we also set our table to the max allowed. TODO we may not want to use max allowed. - self.change_capacity(cap) + self.change_capacity(cap); + Ok(()) } pub fn set_max_blocked_streams(&mut self, blocked_streams: u64) -> Res<()> { - self.max_blocked_streams = blocked_streams - .try_into() - .or(Err(Error::EncoderStreamError))?; + if blocked_streams > (1 << 16) - 1 { + return Err(Error::EncoderStreamError); + } + qdebug!([self], "Set max blocked streams to {}.", blocked_streams); + self.max_blocked_streams = blocked_streams as u16; Ok(()) } @@ -138,7 +136,7 @@ impl QPackEncoder { ) { Ok(done) => { if done { - self.call_instruction()?; + self.call_instruction(); } else { // wait for more data. break Ok(()); @@ -160,7 +158,7 @@ impl QPackEncoder { ) { Ok(done) => { if done { - self.call_instruction()?; + self.call_instruction(); } else { // wait for more data. break Ok(()); @@ -174,95 +172,35 @@ impl QPackEncoder { } } - fn recalculate_blocked_streams(&mut self) { - let acked_inserts_cnt = self.table.get_acked_inserts_cnt(); - self.blocked_stream_cnt = 0; - for (_, hb_list) in self.unacked_header_blocks.iter_mut() { - debug_assert!(!hb_list.is_empty()); - if hb_list - .iter() - .flat_map(|hb| hb.iter()) - .any(|e| *e >= acked_inserts_cnt) - { - self.blocked_stream_cnt += 1; - } - } - } - - fn insert_count_instruction(&mut self, increment: u64) -> Res<()> { - self.table.increment_acked(increment)?; - self.recalculate_blocked_streams(); - Ok(()) - } - - fn header_ack(&mut self, stream_id: u64) -> Res<()> { - let mut new_acked = self.table.get_acked_inserts_cnt(); - if let Some(hb_list) = self.unacked_header_blocks.get_mut(&stream_id) { - if let Some(ref_list) = hb_list.pop_back() { - for iter in ref_list { - self.table.remove_ref(iter); - if iter >= new_acked { - new_acked = iter + 1; - } - } - } else { - debug_assert!(false, "We should have at least one header block."); - } - if hb_list.is_empty() { - self.unacked_header_blocks.remove(&stream_id); - } - } - if new_acked > self.table.get_acked_inserts_cnt() { - self.insert_count_instruction(new_acked - self.table.get_acked_inserts_cnt()) - .expect("This should neve happen"); - } - Ok(()) - } - - fn stream_cancellation(&mut self, stream_id: u64) -> Res<()> { - let mut was_blocker = false; - if let Some(hb_list) = self.unacked_header_blocks.get_mut(&stream_id) { - debug_assert!(!hb_list.is_empty()); - while let Some(ref_list) = hb_list.pop_front() { - for iter in ref_list { - self.table.remove_ref(iter); - was_blocker = was_blocker || (iter >= self.table.get_acked_inserts_cnt()); - } - } - } - if was_blocker { - debug_assert!(self.blocked_stream_cnt > 0); - self.blocked_stream_cnt -= 1; - } - Ok(()) - } - - fn call_instruction(&mut self) -> Res<()> { + fn call_instruction(&mut self) { if let Some(inst) = &self.instruction_reader_current_inst { qdebug!([self], "call intruction {:?}", inst); match inst { DecoderInstructions::InsertCountIncrement => { - self.insert_count_instruction(self.instruction_reader_value)? + self.table.increment_acked(self.instruction_reader_value); + let inserts = self.table.get_acked_inserts_cnt(); + self.blocked_streams.retain(|req| *req <= inserts); + } + DecoderInstructions::HeaderAck => { + self.table.header_ack(self.instruction_reader_value) } - DecoderInstructions::HeaderAck => self.header_ack(self.instruction_reader_value)?, DecoderInstructions::StreamCancellation => { - self.stream_cancellation(self.instruction_reader_value)? + self.table.header_ack(self.instruction_reader_value) } - }; + } self.instruction_reader_current_inst = None; self.instruction_reader_value = 0; self.instruction_reader_cnt = 0; } else { panic!("We must have a instruction decoded beforewe call call_instruction"); } - Ok(()) } pub fn insert_with_name_ref( &mut self, name_static_table: bool, name_index: u64, - value: &[u8], + value: Vec, ) -> Res<()> { qdebug!( [self], @@ -279,23 +217,25 @@ impl QPackEncoder { .insert_with_name_ref(name_static_table, name_index, value)?; // write instruction + let entry = self.table.get_last_added_entry().unwrap(); let instr = 0x80 | (if name_static_table { 0x40 } else { 0x00 }); self.send_buf .encode_prefixed_encoded_int(instr, 2, name_index); - encode_literal(self.use_huffman, &mut self.send_buf, 0x0, 0, value); + encode_literal(self.use_huffman, &mut self.send_buf, 0x0, 0, entry.value()); Ok(()) } - pub fn insert_with_name_literal(&mut self, name: &[u8], value: &[u8]) -> Res { + pub fn insert_with_name_literal(&mut self, name: Vec, value: Vec) -> Res<()> { qdebug!([self], "insert name {:x?}, value={:x?}.", name, value); // try to insert a new entry - let index = self.table.insert(name, value)?; + self.table.insert(name, value)?; + let entry = self.table.get_last_added_entry().unwrap(); // encode instruction. - encode_literal(self.use_huffman, &mut self.send_buf, 0x40, 2, name); - encode_literal(self.use_huffman, &mut self.send_buf, 0x0, 0, value); + encode_literal(self.use_huffman, &mut self.send_buf, 0x40, 2, entry.name()); + encode_literal(self.use_huffman, &mut self.send_buf, 0x0, 0, entry.value()); - Ok(index) + Ok(()) } pub fn duplicate(&mut self, index: u64) -> Res<()> { @@ -305,11 +245,10 @@ impl QPackEncoder { Ok(()) } - pub fn change_capacity(&mut self, cap: u64) -> Res<()> { + pub fn change_capacity(&mut self, cap: u64) { qdebug!([self], "change capacity: {}", cap); - self.table.set_capacity(cap)?; + self.table.set_capacity(cap); self.send_buf.encode_prefixed_encoded_int(0x20, 3, cap); - Ok(()) } pub fn send(&mut self, conn: &mut Connection) -> Res<()> { @@ -329,59 +268,67 @@ impl QPackEncoder { } } - fn is_stream_blocker(&self, stream_id: u64) -> bool { - if let Some(hb_list) = self.unacked_header_blocks.get(&stream_id) { - debug_assert!(!hb_list.is_empty()); - match hb_list.iter().flat_map(|hb| hb.iter()).max() { - Some(max_ref) => *max_ref >= self.table.get_acked_inserts_cnt(), - None => false, - } - } else { - false - } - } - pub fn encode_header_block(&mut self, h: &[Header], stream_id: u64) -> QPData { qdebug!([self], "encoding headers."); let mut encoded_h = QPData::default(); let base = self.table.base(); + let mut req_insert_cnt = 0; self.encode_header_block_prefix(&mut encoded_h, false, 0, base, true); - - let stream_is_blocker = self.is_stream_blocker(stream_id); - let can_block = self.blocked_stream_cnt < self.max_blocked_streams || stream_is_blocker; - - let mut ref_entries = HashSet::new(); - for iter in h.iter() { let name = iter.0.clone().into_bytes(); let value = iter.1.clone().into_bytes(); qtrace!("encoding {:x?} {:x?}.", name, value); - if let Some(LookupResult { - index, - static_table, - value_matches, - }) = self.table.lookup(&name, &value, can_block) + let mut can_use = false; + let mut index: u64 = 0; + let mut value_as_well = false; + let mut is_dynamic = false; + let acked_inserts_cnt = self.table.get_acked_inserts_cnt(); // we need to read it here because of borrowing problem. + let can_be_blocked = self.blocked_streams.len() < self.max_blocked_streams as usize; { - qtrace!( - [self], - "found a {} entry, value-match={}", - if static_table { "static" } else { "dynamic" }, - value_matches - ); - if static_table { - if value_matches { - self.encode_indexed(&mut encoded_h, true, index); - } else { - self.encode_literal_with_name_ref(&mut encoded_h, true, index, &value); - } - } else { - if value_matches { - if index < base { - self.encode_indexed(&mut encoded_h, false, base - index - 1); - } else { - self.encode_post_base_index(&mut encoded_h, index - base); + let label = self.to_string(); + // this is done in this way because otherwise it is complaining about mut borrow. TODO: look if we can do this better + let (e_s, e_d, found_value) = self.table.lookup(&name, &value); + if let Some(entry) = e_s { + qtrace!([label], "found a static entry, value-match={}", found_value); + can_use = true; + index = entry.index(); + value_as_well = found_value; + } + if !can_use { + if let Some(entry) = e_d { + index = entry.index(); + can_use = index < acked_inserts_cnt || can_be_blocked; + qtrace!( + [label], + "found a dynamic entry - can_use={} value-match={},", + can_use, + found_value + ); + if can_use { + value_as_well = found_value; + is_dynamic = true; + entry.add_ref(stream_id, 1); } + } + } + } + if can_use { + if value_as_well { + if !is_dynamic { + self.encode_indexed(&mut encoded_h, true, index); + } else if index < base { + self.encode_indexed(&mut encoded_h, false, base - index - 1); + } else { + self.encode_post_base_index(&mut encoded_h, index - base); + } + if is_dynamic && req_insert_cnt < (index + 1) { + req_insert_cnt = index + 1; + } + continue; + } else { + if !is_dynamic { + self.encode_literal_with_name_ref(&mut encoded_h, true, index, &value); } else if index < base { self.encode_literal_with_name_ref( &mut encoded_h, @@ -396,40 +343,35 @@ impl QPackEncoder { &value, ); } - ref_entries.insert(index); + + if is_dynamic && req_insert_cnt < (index + 1) { + req_insert_cnt = index + 1; + } + continue; } - } else if !can_block { - self.encode_literal_with_name_literal(&mut encoded_h, &name, &value); - } else { - match self.insert_with_name_literal(&name, &value) { - Ok(index) => { - self.encode_post_base_index(&mut encoded_h, index - base); - ref_entries.insert(index); - } - Err(_) => { - self.encode_literal_with_name_literal(&mut encoded_h, &name, &value); + } + + let name2 = name.clone(); + let value2 = value.clone(); + match self.insert_with_name_literal(name2, value2) { + Err(_) => { + self.encode_literal_with_name_literal(&mut encoded_h, &name, &value); + } + Ok(()) => { + let index: u64; + { + let entry = self.table.get_last_added_entry().unwrap(); + entry.add_ref(stream_id, 1); + index = entry.index(); } + self.encode_post_base_index(&mut encoded_h, index - base); + + req_insert_cnt = index + 1; } } } - for iter in &ref_entries { - self.table.add_ref(*iter); - } - - if let Some(max_ref) = ref_entries.iter().max() { - self.fix_header_block_prefix(&mut encoded_h, base, *max_ref + 1); - // Check if it is already blocking - if !stream_is_blocker && *max_ref >= self.table.get_acked_inserts_cnt() { - debug_assert!(self.blocked_stream_cnt < self.max_blocked_streams); - self.blocked_stream_cnt += 1; - } - } - - if !ref_entries.is_empty() { - self.unacked_header_blocks - .entry(stream_id) - .or_insert_with(VecDeque::new) - .push_front(ref_entries); + if req_insert_cnt > 0 { + self.fix_header_block_prefix(&mut encoded_h, base, req_insert_cnt); } encoded_h } @@ -562,11 +504,6 @@ impl QPackEncoder { } } } - - #[cfg(test)] - pub fn blocked_stream_cnt(&self) -> u16 { - self.blocked_stream_cnt - } } fn encode_literal(use_huffman: bool, buf: &mut QPData, prefix: u8, prefix_len: u8, value: &[u8]) { @@ -593,173 +530,213 @@ impl ::std::fmt::Display for QPackEncoder { #[cfg(test)] mod tests { use super::*; + use neqo_transport::ConnectionEvent; use neqo_transport::StreamType; use test_fixture::*; - struct TestEncoder { - encoder: QPackEncoder, - send_stream_id: u64, - recv_stream_id: u64, - conn: Connection, - peer_conn: Connection, - } - - fn connect(huffman: bool) -> TestEncoder { - let (mut conn, mut peer_conn) = test_fixture::connect(); + fn connect(huffman: bool) -> (QPackEncoder, Connection, Connection, u64, u64) { + let (mut conn_c, mut conn_s) = test_fixture::connect(); // create a stream - let recv_stream_id = peer_conn.stream_create(StreamType::UniDi).unwrap(); - let send_stream_id = conn.stream_create(StreamType::UniDi).unwrap(); + let recv_stream_id = conn_s.stream_create(StreamType::UniDi).unwrap(); + let send_stream_id = conn_c.stream_create(StreamType::UniDi).unwrap(); // create an encoder let mut encoder = QPackEncoder::new(huffman); encoder.add_send_stream(send_stream_id); - TestEncoder { - encoder, - send_stream_id, - recv_stream_id, - conn, - peer_conn, + (encoder, conn_c, conn_s, recv_stream_id, send_stream_id) + } + + fn test_sent_instructions( + encoder: &mut QPackEncoder, + mut conn_c: &mut Connection, + conn_s: &mut Connection, + recv_stream_id: u64, + send_stream_id: u64, + encoder_instruction: &[u8], + ) { + encoder.send(&mut conn_c).unwrap(); + let out = conn_c.process(None, now()); + conn_s.process(out.dgram(), now()); + let mut found_instruction = false; + while let Some(e) = conn_s.next_event() { + if let ConnectionEvent::RecvStreamReadable { stream_id } = e { + let mut buf = [0u8; 100]; + let (amount, fin) = conn_s.stream_recv(stream_id, &mut buf).unwrap(); + assert_eq!(fin, false); + assert_eq!(buf[..amount], encoder_instruction[..]); + found_instruction = true; + } } + assert_eq!(found_instruction, !encoder_instruction.is_empty()); } - fn send_instructions(encoder: &mut TestEncoder, encoder_instruction: &[u8]) { - encoder.encoder.send(&mut encoder.conn).unwrap(); - let out = encoder.conn.process(None, now()); - encoder.peer_conn.process(out.dgram(), now()); - let mut buf = [0u8; 100]; - let (amount, fin) = encoder - .peer_conn - .stream_recv(encoder.send_stream_id, &mut buf) - .unwrap(); - assert_eq!(fin, false); - assert_eq!(buf[..amount], encoder_instruction[..]); - } - - fn recv_instruction(encoder: &mut TestEncoder, decoder_instruction: &[u8]) { - encoder - .peer_conn - .stream_send(encoder.recv_stream_id, decoder_instruction) - .unwrap(); - let out = encoder.peer_conn.process(None, now()); - encoder.conn.process(out.dgram(), now()); - assert!(encoder - .encoder - .read_instructions(&mut encoder.conn, encoder.recv_stream_id) - .is_ok()); - } - - const CAP_INSTRUCTION_200: &[u8] = &[0x02, 0x3f, 0xa9, 0x01]; - const CAP_INSTRUCTION_60: &[u8] = &[0x02, 0x3f, 0x1d]; - - const HEADER_CONTENT_LENGTH: &[u8] = &[ - 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, - ]; - const VALUE_1: &[u8] = &[0x31, 0x32, 0x33, 0x34]; - const VALUE_2: &[u8] = &[0x31, 0x32, 0x33, 0x34, 0x35]; - - // HEADER_CONTENT_LENGTH and VALUE_1 encoded by instruction insert_with_name_literal. - const HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL: &[u8] = &[ - 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, - 0x04, 0x31, 0x32, 0x33, 0x34, - ]; - - // HEADER_CONTENT_LENGTH and VALUE_2 encoded by instruction insert_with_name_literal. - const HEADER_CONTENT_LENGTH_VALUE_2_NAME_LITERAL: &[u8] = &[ - 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, - 0x05, 0x31, 0x32, 0x33, 0x34, 0x35, - ]; - - // Indexed Header Field that refers to the first entry in the dynamic table. - const ENCODE_INDEXED_REF_DYNAMIC: &[u8] = &[0x02, 0x00, 0x80]; - - const HEADER_ACK_STREAM_ID_1: &[u8] = &[0x81]; - const HEADER_ACK_STREAM_ID_2: &[u8] = &[0x82]; - const STREAM_CANCELED_ID_1: &[u8] = &[0x41]; - // test insert_with_name_ref which fails because there is not enough space in the table #[test] fn test_insert_with_name_ref_1() { - let mut encoder = connect(false); + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(false); let e = encoder - .encoder - .insert_with_name_ref(true, 4, VALUE_1) + .insert_with_name_ref(true, 4, vec![0x31, 0x32, 0x33, 0x34]) .unwrap_err(); assert_eq!(Error::EncoderStreamError, e); - send_instructions(&mut encoder, &[0x02]); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02], + ); } // test insert_name_ref that succeeds #[test] fn test_insert_with_name_ref_2() { - let mut encoder = connect(false); - assert!(encoder.encoder.set_max_capacity(200).is_ok()); + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(false); + assert!(encoder.set_max_capacity(200).is_ok()); // test the change capacity instruction. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02, 0x3f, 0xa9, 0x01], + ); assert!(encoder - .encoder - .insert_with_name_ref(true, 4, VALUE_1) + .insert_with_name_ref(true, 4, vec![0x31, 0x32, 0x33, 0x34]) .is_ok()); - send_instructions(&mut encoder, &[0xc4, 0x04, 0x31, 0x32, 0x33, 0x34]); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0xc4, 0x04, 0x31, 0x32, 0x33, 0x34], + ); } // test insert_with_name_literal which fails because there is not enough space in the table #[test] fn test_insert_with_name_literal_1() { - let mut encoder = connect(false); + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(false); // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); + let res = encoder.insert_with_name_literal( + vec![ + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + ], + vec![0x31, 0x32, 0x33, 0x34], + ); assert_eq!(Error::EncoderStreamError, res.unwrap_err()); - send_instructions(&mut encoder, &[0x02]); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02], + ); } // test insert_with_name_literal - succeeds #[test] fn test_insert_with_name_literal_2() { - let mut encoder = connect(false); + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(false); - assert!(encoder.encoder.set_max_capacity(200).is_ok()); + assert!(encoder.set_max_capacity(200).is_ok()); // test the change capacity instruction. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02, 0x3f, 0xa9, 0x01], + ); // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); + let res = encoder.insert_with_name_literal( + vec![ + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + ], + vec![0x31, 0x32, 0x33, 0x34], + ); assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[ + 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, + 0x68, 0x04, 0x31, 0x32, 0x33, 0x34, + ], + ); } #[test] fn test_change_capacity() { - let mut encoder = connect(false); + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(false); - assert!(encoder.encoder.set_max_capacity(200).is_ok()); - send_instructions(&mut encoder, CAP_INSTRUCTION_200); + assert!(encoder.set_max_capacity(200).is_ok()); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02, 0x3f, 0xa9, 0x01], + ); } #[test] fn test_duplicate() { - let mut encoder = connect(false); + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(false); - assert!(encoder.encoder.set_max_capacity(200).is_ok()); + assert!(encoder.set_max_capacity(200).is_ok()); // test the change capacity instruction. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02, 0x3f, 0xa9, 0x01], + ); // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); + let res = encoder.insert_with_name_literal( + vec![ + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + ], + vec![0x31, 0x32, 0x33, 0x34], + ); assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[ + 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, + 0x68, 0x04, 0x31, 0x32, 0x33, 0x34, + ], + ); - assert!(encoder.encoder.duplicate(0).is_ok()); - send_instructions(&mut encoder, &[0x00]); + assert!(encoder.duplicate(0).is_ok()); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x00], + ); } struct TestElement { @@ -798,7 +775,7 @@ mod tests { // test encode_indexed with a ref to dynamic table. TestElement { headers: vec![(String::from("my-header"), String::from("my-value"))], - header_block: ENCODE_INDEXED_REF_DYNAMIC, + header_block: &[0x02, 0x00, 0x80], encoder_inst: &[], }, // test encode_literal_with_name_ref. @@ -826,18 +803,32 @@ mod tests { }, ]; - let mut encoder = connect(false); + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(false); - encoder.encoder.set_max_blocked_streams(100).unwrap(); - encoder.encoder.set_max_capacity(200).unwrap(); + encoder.set_max_blocked_streams(100).unwrap(); + encoder.set_max_capacity(200).unwrap(); // test the change capacity instruction. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02, 0x3f, 0xa9, 0x01], + ); for t in &test_cases { - let buf = encoder.encoder.encode_header_block(&t.headers, 1); + let buf = encoder.encode_header_block(&t.headers, 1); assert_eq!(&buf[..], t.header_block); - send_instructions(&mut encoder, t.encoder_inst); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + t.encoder_inst, + ); } } @@ -870,7 +861,7 @@ mod tests { // test encode_indexed with a ref to dynamic table. TestElement { headers: vec![(String::from("my-header"), String::from("my-value"))], - header_block: ENCODE_INDEXED_REF_DYNAMIC, + header_block: &[0x02, 0x00, 0x80], encoder_inst: &[], }, // test encode_literal_with_name_ref. @@ -897,600 +888,237 @@ mod tests { }, ]; - let mut encoder = connect(true); + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(true); - encoder.encoder.set_max_blocked_streams(100).unwrap(); - encoder.encoder.set_max_capacity(200).unwrap(); + encoder.set_max_blocked_streams(100).unwrap(); + encoder.set_max_capacity(200).unwrap(); // test the change capacity instruction. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02, 0x3f, 0xa9, 0x01], + ); for t in &test_cases { - let buf = encoder.encoder.encode_header_block(&t.headers, 1); + let buf = encoder.encode_header_block(&t.headers, 1); assert_eq!(&buf[..], t.header_block); - send_instructions(&mut encoder, t.encoder_inst); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + t.encoder_inst, + ); } } // Test inserts block on waiting for an insert count increment. #[test] fn test_insertion_blocked_on_insert_count_feedback() { - let mut encoder = connect(false); + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(false); - encoder.encoder.set_max_capacity(60).unwrap(); + encoder.set_max_capacity(60).unwrap(); // test the change capacity instruction. - send_instructions(&mut encoder, CAP_INSTRUCTION_60); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02, 0x3f, 0x1d], + ); // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); + let res = encoder.insert_with_name_literal( + vec![ + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + ], + vec![0x31, 0x32, 0x33, 0x34], + ); assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[ + 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, + 0x68, 0x04, 0x31, 0x32, 0x33, 0x34, + ], + ); // insert "content-length: 12345 which will fail because the ntry in the table cannot be evicted. - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_2); + let res = encoder.insert_with_name_literal( + vec![ + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + ], + vec![0x31, 0x32, 0x33, 0x34, 0x35], + ); assert!(res.is_err()); - send_instructions(&mut encoder, &[]); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[], + ); // receive an insert count increment. - recv_instruction(&mut encoder, &[0x01]); + conn_s.stream_send(recv_stream_id, &[0x01]).unwrap(); + let out = conn_s.process(None, now()); + conn_c.process(out.dgram(), now()); + assert!(encoder + .read_instructions(&mut conn_c, recv_stream_id) + .is_ok()); // insert "content-length: 12345 again it will succeed. - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_2); + let res = encoder.insert_with_name_literal( + vec![ + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + ], + vec![0x31, 0x32, 0x33, 0x34, 0x35], + ); assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_2_NAME_LITERAL); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[ + 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, + 0x68, 0x05, 0x31, 0x32, 0x33, 0x34, 0x35, + ], + ); } // Test inserts block on waiting for acks - // test the table insertion is blocked: - // 0 - waiting for a header ack - // 2 - waiting for a stream cancel. - fn test_insertion_blocked_on_waiting_for_header_ack_or_stream_cancel(wait: u8) { - let mut encoder = connect(false); + // test the table inseriong blocking: + // 0 - waithing for a header ack + // 2 - waithing for a stream cancel. + fn test_insertion_blocked_on_waiting_forheader_ack_or_stream_cancel(wait: u8) { + let (mut encoder, mut conn_c, mut conn_s, recv_stream_id, send_stream_id) = connect(false); - assert!(encoder.encoder.set_max_capacity(60).is_ok()); + assert!(encoder.set_max_capacity(60).is_ok()); // test the change capacity instruction. - send_instructions(&mut encoder, CAP_INSTRUCTION_60); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[0x02, 0x3f, 0x1d], + ); // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); + let res = encoder.insert_with_name_literal( + vec![ + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + ], + vec![0x31, 0x32, 0x33, 0x34], + ); assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[ + 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, + 0x68, 0x04, 0x31, 0x32, 0x33, 0x34, + ], + ); // receive an insert count increment. - recv_instruction(&mut encoder, &[0x01]); + let _ = conn_s.stream_send(recv_stream_id, &[0x01]); + let out = conn_s.process(None, now()); + conn_c.process(out.dgram(), now()); + assert!(encoder + .read_instructions(&mut conn_c, recv_stream_id) + .is_ok()); // send a header block let buf = encoder - .encoder .encode_header_block(&[(String::from("content-length"), String::from("1234"))], 1); - assert_eq!(&buf[..], ENCODE_INDEXED_REF_DYNAMIC); - send_instructions(&mut encoder, &[]); + assert_eq!(&buf[..], &[0x02, 0x00, 0x80]); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[], + ); - // insert "content-length: 12345 which will fail because the entry in the table cannot be evicted - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_2); + // insert "content-length: 12345 which will fail because the entry in the table cannot be evicted. + let res = encoder.insert_with_name_literal( + vec![ + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + ], + vec![0x31, 0x32, 0x33, 0x34, 0x35], + ); assert!(res.is_err()); - send_instructions(&mut encoder, &[]); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[], + ); if wait == 0 { // receive a header_ack. - recv_instruction(&mut encoder, HEADER_ACK_STREAM_ID_1); + let _ = conn_s.stream_send(recv_stream_id, &[0x81]); + let out = conn_s.process(None, now()); + conn_c.process(out.dgram(), now()); } else { - // receive a stream canceled - recv_instruction(&mut encoder, STREAM_CANCELED_ID_1); + // reveice a stream canceled + let _ = conn_s.stream_send(recv_stream_id, &[0x41]); + let out = conn_s.process(None, now()); + conn_c.process(out.dgram(), now()); } + assert!(encoder + .read_instructions(&mut conn_c, recv_stream_id) + .is_ok()); // insert "content-length: 12345 again it will succeed. - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_2); + let res = encoder.insert_with_name_literal( + vec![ + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, + ], + vec![0x31, 0x32, 0x33, 0x34, 0x35], + ); assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_2_NAME_LITERAL); + test_sent_instructions( + &mut encoder, + &mut conn_c, + &mut conn_s, + recv_stream_id, + send_stream_id, + &[ + 0x4e, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, + 0x68, 0x05, 0x31, 0x32, 0x33, 0x34, 0x35, + ], + ); } #[test] fn test_header_ack() { - test_insertion_blocked_on_waiting_for_header_ack_or_stream_cancel(0); + test_insertion_blocked_on_waiting_forheader_ack_or_stream_cancel(0); } #[test] fn test_stream_canceled() { - test_insertion_blocked_on_waiting_for_header_ack_or_stream_cancel(1); - } - - fn assert_is_index_to_dynamic(buf: &[u8]) { - assert_eq!(buf[2] & 0xc0, 0x80); - } - - fn assert_is_index_to_dynamic_post(buf: &[u8]) { - assert_eq!(buf[2] & 0xf0, 0x10); - } - - fn assert_is_index_to_static_name_only(buf: &[u8]) { - assert_eq!(buf[2] & 0xf0, 0x50); - } - - fn assert_is_literal_value_literal_name(buf: &[u8]) { - assert_eq!(buf[2] & 0xf0, 0x20); - } - - #[test] - fn max_block_streams1() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(60).is_ok()); - - // change capacity to 60. - send_instructions(&mut encoder, CAP_INSTRUCTION_60); - - // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); - - assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); - - encoder.encoder.set_max_blocked_streams(1).unwrap(); - - // send a header block, it refers to unacked entry. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("content-length"), String::from("1234"))], 1); - assert_is_index_to_dynamic(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - send_instructions(&mut encoder, &[]); - - // The next one will not use the dynamic entry because it is exceeding the max_blocked_streams - // limit. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("content-length"), String::from("1234"))], 2); - assert_is_index_to_static_name_only(&buf); - - send_instructions(&mut encoder, &[]); - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // another header block to already blocked stream can still use the entry. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("content-length"), String::from("1234"))], 1); - assert_is_index_to_dynamic(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - } - - #[test] - fn max_block_streams2() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(200).is_ok()); - - // change capacity to 200. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); - - // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); - - assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); - - // insert "content-length: 12345 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_2); - - assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_2_NAME_LITERAL); - - encoder.encoder.set_max_blocked_streams(1).unwrap(); - - let stream_id = 1; - // send a header block, it refers to unacked entry. - let buf = encoder.encoder.encode_header_block( - &[(String::from("content-length"), String::from("1234"))], - stream_id, - ); - assert_is_index_to_dynamic(&buf); - - // encode another header block for the same stream that will refer to the second entry - // in the dynamic table. - // This should work because the stream is already a blocked stream - // send a header block, it refers to unacked entry. - let buf = encoder.encoder.encode_header_block( - &[(String::from("content-length"), String::from("12345"))], - stream_id, - ); - assert_is_index_to_dynamic(&buf); - } - - #[test] - fn max_block_streams3() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(200).is_ok()); - - // change capacity to 200. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); - - encoder.encoder.set_max_blocked_streams(1).unwrap(); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 0); - - // send a header block, that creates an new entry and refers to it. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name1"), String::from("value1"))], 1); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // The next one will not create a new entry because the encoder is on max_blocked_streams limit. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name2"), String::from("value2"))], 2); - assert_is_literal_value_literal_name(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // another header block to already blocked stream can still create a new entry. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name2"), String::from("value2"))], 1); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - } - - #[test] - fn max_block_streams4() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(200).is_ok()); - - // change capacity to 200. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); - - encoder.encoder.set_max_blocked_streams(1).unwrap(); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 0); - - // send a header block, that creates an new entry and refers to it. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name1"), String::from("value1"))], 1); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // another header block to already blocked stream can still create a new entry. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name2"), String::from("value2"))], 1); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // receive a header_ack for the first header block. - recv_instruction(&mut encoder, HEADER_ACK_STREAM_ID_1); - - // The stream is still blocking because the second header block is not acked. - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - } - - #[test] - fn max_block_streams5() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(200).is_ok()); - - // change capacity to 200. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); - - encoder.encoder.set_max_blocked_streams(1).unwrap(); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 0); - - // send a header block, that creates an new entry and refers to it. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name1"), String::from("value1"))], 1); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // another header block to already blocked stream can still create a new entry. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name1"), String::from("value1"))], 1); - assert_is_index_to_dynamic(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // receive a header_ack for the first header block. - recv_instruction(&mut encoder, HEADER_ACK_STREAM_ID_1); - - // The stream is not blocking anymore because header ack also acks the instruction. - assert_eq!(encoder.encoder.blocked_stream_cnt(), 0); - } - - #[test] - fn max_block_streams6() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(200).is_ok()); - - // change capacity to 200. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); - - encoder.encoder.set_max_blocked_streams(2).unwrap(); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 0); - - // send a header block, that creates an new entry and refers to it. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name1"), String::from("value1"))], 1); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // header block for the next stream will create an new entry as well. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name2"), String::from("value2"))], 2); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 2); - - // receive a header_ack for the second header block. This will ack the first as well - recv_instruction(&mut encoder, HEADER_ACK_STREAM_ID_2); - - // The stream is not blocking anymore because header ack also acks the instruction. - assert_eq!(encoder.encoder.blocked_stream_cnt(), 0); - } - - #[test] - fn max_block_streams7() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(200).is_ok()); - - // change capacity to 200. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); - - encoder.encoder.set_max_blocked_streams(2).unwrap(); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 0); - - // send a header block, that creates an new entry and refers to it. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name1"), String::from("value1"))], 1); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // header block for the next stream will create an new entry as well. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name1"), String::from("value1"))], 2); - assert_is_index_to_dynamic(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 2); - - // receive a stream cancel for the first stream. - // This will remove the first stream as blocking but it will not mark the instruction as acked. - // and the second steam will still be blocking. - recv_instruction(&mut encoder, STREAM_CANCELED_ID_1); - - // The stream is not blocking anymore because header ack also acks the instruction. - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - } - - #[test] - fn max_block_stream8() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(200).is_ok()); - - // change capacity to 200. - send_instructions(&mut encoder, CAP_INSTRUCTION_200); - - encoder.encoder.set_max_blocked_streams(2).unwrap(); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 0); - - // send a header block, that creates an new entry and refers to it. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name1"), String::from("value1"))], 1); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - - // header block for the next stream will refer to the same entry. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name1"), String::from("value1"))], 2); - assert_is_index_to_dynamic(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 2); - - // send another header block on stream 1. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("name2"), String::from("value2"))], 1); - assert_is_index_to_dynamic_post(&buf); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 2); - - // stream 1 is block on entries 1 and 2; stream 2 is block only on 1. - // receive an Insert Count Increment for the first entry. - // After that only stream 1 will be blocking. - recv_instruction(&mut encoder, &[0x01]); - - assert_eq!(encoder.encoder.blocked_stream_cnt(), 1); - } - - #[test] - fn dynamic_table_can_evict1() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(60).is_ok()); - - // change capacity to 60. - send_instructions(&mut encoder, CAP_INSTRUCTION_60); - - encoder.encoder.set_max_blocked_streams(2).unwrap(); - - // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); - - assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); - - // send a header block, it refers to unacked entry. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("content-length"), String::from("1234"))], 1); - assert_is_index_to_dynamic(&buf); - - // trying to evict the entry will failed. - assert!(encoder.encoder.set_max_capacity(10).is_err()); - - // receive an Insert Count Increment for the entry. - recv_instruction(&mut encoder, &[0x01]); - - // trying to evict the entry will failed. The stream is still referring to it. - assert!(encoder.encoder.set_max_capacity(10).is_err()); - - // receive a header_ack for the header block. - recv_instruction(&mut encoder, HEADER_ACK_STREAM_ID_1); - - // now entry can be evicted. - assert!(encoder.encoder.set_max_capacity(10).is_ok()); - } - - #[test] - fn dynamic_table_can_evict2() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(60).is_ok()); - - // change capacity to 60. - send_instructions(&mut encoder, CAP_INSTRUCTION_60); - - encoder.encoder.set_max_blocked_streams(2).unwrap(); - - // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); - - assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); - - // send a header block, it refers to unacked entry. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("content-length"), String::from("1234"))], 1); - assert_is_index_to_dynamic(&buf); - - // trying to evict the entry will failed. - assert!(encoder.encoder.set_max_capacity(10).is_err()); - - // receive an Insert Count Increment for the entry. - recv_instruction(&mut encoder, &[0x01]); - - // trying to evict the entry will failed. The stream is still referring to it. - assert!(encoder.encoder.set_max_capacity(10).is_err()); - - // receive a stream cancelled. - recv_instruction(&mut encoder, STREAM_CANCELED_ID_1); - - // now entry can be evicted. - assert!(encoder.encoder.set_max_capacity(10).is_ok()); - } - - #[test] - fn dynamic_table_can_evict3() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(60).is_ok()); - - // change capacity to 60. - send_instructions(&mut encoder, CAP_INSTRUCTION_60); - - encoder.encoder.set_max_blocked_streams(2).unwrap(); - - // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); - - assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); - - // trying to evict the entry will failed, because the entry is not acked. - assert!(encoder.encoder.set_max_capacity(10).is_err()); - - // receive an Insert Count Increment for the entry. - recv_instruction(&mut encoder, &[0x01]); - - // now entry can be evicted. - assert!(encoder.encoder.set_max_capacity(10).is_ok()); - } - - #[test] - fn dynamic_table_can_evict4() { - let mut encoder = connect(false); - - assert!(encoder.encoder.set_max_capacity(60).is_ok()); - - // change capacity to 60. - send_instructions(&mut encoder, CAP_INSTRUCTION_60); - - encoder.encoder.set_max_blocked_streams(2).unwrap(); - - // insert "content-length: 1234 - let res = encoder - .encoder - .insert_with_name_literal(HEADER_CONTENT_LENGTH, VALUE_1); - - assert!(res.is_ok()); - send_instructions(&mut encoder, HEADER_CONTENT_LENGTH_VALUE_1_NAME_LITERAL); - - // send a header block, it refers to unacked entry. - let buf = encoder - .encoder - .encode_header_block(&[(String::from("content-length"), String::from("1234"))], 1); - assert_is_index_to_dynamic(&buf); - - // trying to evict the entry will failed. The stream is still referring to it and - // entry is not acked. - assert!(encoder.encoder.set_max_capacity(10).is_err()); - - // receive a header_ack for the header block. This will also ack the instruction. - recv_instruction(&mut encoder, HEADER_ACK_STREAM_ID_1); - - // now entry can be evicted. - assert!(encoder.encoder.set_max_capacity(10).is_ok()); + test_insertion_blocked_on_waiting_forheader_ack_or_stream_cancel(1); } } diff --git a/third_party/rust/neqo-qpack/src/lib.rs b/third_party/rust/neqo-qpack/src/lib.rs index 2abb2f34cd49..d96ad498208c 100644 --- a/third_party/rust/neqo-qpack/src/lib.rs +++ b/third_party/rust/neqo-qpack/src/lib.rs @@ -5,7 +5,6 @@ // except according to those terms. #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::use_self)] pub mod decoder; pub mod encoder; @@ -26,12 +25,6 @@ enum QPackSide { Decoder, } -impl ::std::fmt::Display for QPackSide { - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "{:?}", self) - } -} - #[derive(Clone, Debug, PartialEq)] pub enum Error { DecompressionFailed, @@ -44,7 +37,6 @@ pub enum Error { NoMoreData, IntegerOverflow, WrongStreamCount, - InternalError, TransportError(neqo_transport::Error), } @@ -59,7 +51,7 @@ impl Error { impl ::std::error::Error for Error { fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> { match self { - Self::TransportError(e) => Some(e), + Error::TransportError(e) => Some(e), _ => None, } } @@ -73,6 +65,6 @@ impl ::std::fmt::Display for Error { impl From for Error { fn from(err: neqo_transport::Error) -> Self { - Self::TransportError(err) + Error::TransportError(err) } } diff --git a/third_party/rust/neqo-qpack/src/table.rs b/third_party/rust/neqo-qpack/src/table.rs index 2f2d8a593441..72aca5b14ebf 100644 --- a/third_party/rust/neqo-qpack/src/table.rs +++ b/third_party/rust/neqo-qpack/src/table.rs @@ -6,41 +6,31 @@ use crate::static_table::{StaticTableEntry, HEADER_STATIC_TABLE}; use crate::{Error, QPackSide, Res}; -use neqo_common::qtrace; -use std::collections::VecDeque; -use std::convert::TryFrom; - -pub struct LookupResult { - pub index: u64, - pub static_table: bool, - pub value_matches: bool, -} +use std::collections::{HashMap, VecDeque}; #[derive(Debug)] pub struct DynamicTableEntry { base: u64, name: Vec, value: Vec, - /// Number of streams that refer this entry. - refs: u64, + refs: HashMap, //TODO multiple header. value will be used for that: or of flags 0x1 for headers, ox2 for trailes. } impl DynamicTableEntry { pub fn can_reduce(&self, first_not_acked: u64) -> bool { - self.refs == 0 && self.base < first_not_acked + self.refs.is_empty() && self.base < first_not_acked } pub fn size(&self) -> u64 { (self.name.len() + self.value.len() + 32) as u64 } - pub fn add_ref(&mut self) { - self.refs += 1; + pub fn add_ref(&mut self, stream_id: u64, _block: u8) { + self.refs.insert(stream_id, 1); } - pub fn remove_ref(&mut self) { - assert!(self.refs > 0); - self.refs -= 1; + pub fn remove_ref(&mut self, stream_id: u64, _block: u8) { + self.refs.remove(&stream_id); } pub fn name(&self) -> &[u8] { @@ -60,26 +50,19 @@ impl DynamicTableEntry { pub struct HeaderTable { qpack_side: QPackSide, dynamic: VecDeque, - // The total capacity (in QPACK bytes) of the table. This is set by + // The total capacity (in HPACK bytes) of the table. This is set by // configuration. capacity: u64, // The amount of used capacity. used: u64, // The total number of inserts thus far. base: u64, - // This is number of inserts that are acked. this correspond to index of the first not acked. acked_inserts_cnt: u64, } -impl ::std::fmt::Display for HeaderTable { - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "HeaderTable for {}", self.qpack_side) - } -} - impl HeaderTable { - pub fn new(encoder: bool) -> Self { - Self { + pub fn new(encoder: bool) -> HeaderTable { + HeaderTable { qpack_side: if encoder { QPackSide::Encoder } else { @@ -89,7 +72,7 @@ impl HeaderTable { capacity: 0, used: 0, base: 0, - acked_inserts_cnt: 0, + acked_inserts_cnt: if encoder { 0 } else { std::u64::MAX }, } } @@ -101,13 +84,9 @@ impl HeaderTable { self.capacity } - pub fn set_capacity(&mut self, cap: u64) -> Res<()> { - qtrace!([self], "set capacity to {}", cap); - if !self.evict_to(cap) { - return Err(Error::InternalError); - } - self.capacity = cap; - Ok(()) + pub fn set_capacity(&mut self, c: u64) { + self.evict_to(c); + self.capacity = c; } pub fn get_static(&self, index: u64) -> Res<&StaticTableEntry> { @@ -118,126 +97,75 @@ impl HeaderTable { Ok(res) } - fn get_dynamic_with_abs_index(&mut self, index: u64) -> Res<&mut DynamicTableEntry> { - if self.base <= index { - debug_assert!(false, "This is an iternal error"); - return Err(Error::InternalError); - } - let inx = self.base - index - 1; - let inx = usize::try_from(inx).or(Err(Error::HeaderLookupError))?; - if inx >= self.dynamic.len() { - return Err(Error::HeaderLookupError); - } - Ok(&mut self.dynamic[inx]) - } - - fn get_dynamic_with_relative_index(&self, index: u64) -> Res<&DynamicTableEntry> { - let inx = usize::try_from(index).or(Err(Error::HeaderLookupError))?; - if inx >= self.dynamic.len() { - return Err(Error::HeaderLookupError); - } - Ok(&self.dynamic[inx]) - } - - pub fn get_dynamic(&self, index: u64, base: u64, post: bool) -> Res<&DynamicTableEntry> { + pub fn get_dynamic(&self, i: u64, base: u64, post: bool) -> Res<&DynamicTableEntry> { if self.base < base { return Err(Error::HeaderLookupError); } let inx: u64; let base_rel = self.base - base; if post { - if base_rel <= index { + if base_rel <= i { return Err(Error::HeaderLookupError); } - inx = base_rel - index - 1; + inx = base_rel - i - 1; } else { - inx = base_rel + index; + inx = base_rel + i; } - - self.get_dynamic_with_relative_index(inx) + if inx as usize >= self.dynamic.len() { + return Err(Error::HeaderLookupError); + } + let res = &self.dynamic[inx as usize]; + Ok(res) } - pub fn remove_ref(&mut self, index: u64) { - qtrace!([self], "remove reference to entry {}", index); - self.get_dynamic_with_abs_index(index) - .expect("we should have the entry") - .remove_ref(); + pub fn get_last_added_entry(&mut self) -> Option<&mut DynamicTableEntry> { + self.dynamic.front_mut() } - pub fn add_ref(&mut self, index: u64) { - qtrace!([self], "add reference to entry {}", index); - self.get_dynamic_with_abs_index(index) - .expect("we should have the entry") - .add_ref(); - } - - pub fn lookup(&mut self, name: &[u8], value: &[u8], can_block: bool) -> Option { - qtrace!( - [self], - "lookup name:{:?} value {:?} can_block={}", - name, - value, - can_block - ); - let mut name_match = None; + // separate lookups because static entries can not be return mut and we need dynamic entries mutable. + pub fn lookup( + &mut self, + name: &[u8], + value: &[u8], + ) -> ( + Option<&StaticTableEntry>, + Option<&mut DynamicTableEntry>, + bool, + ) { + let mut name_match_s: Option<&StaticTableEntry> = None; for iter in HEADER_STATIC_TABLE.iter() { if iter.name() == name { if iter.value() == value { - return Some(LookupResult { - index: iter.index(), - static_table: true, - value_matches: true, - }); + return (Some(iter), None, true); } - if name_match.is_none() { - name_match = Some(LookupResult { - index: iter.index(), - static_table: true, - value_matches: false, - }); + if name_match_s.is_none() { + name_match_s = Some(iter); } } } + let mut name_match_d: Option<&mut DynamicTableEntry> = None; for iter in self.dynamic.iter_mut() { - if !can_block && iter.index() >= self.acked_inserts_cnt { - continue; - } if iter.name == name { if iter.value == value { - return Some(LookupResult { - index: iter.index(), - static_table: false, - value_matches: true, - }); + return (None, Some(iter), true); } - if name_match.is_none() { - name_match = Some(LookupResult { - index: iter.index(), - static_table: false, - value_matches: false, - }); + if name_match_s.is_none() && name_match_d.is_none() { + name_match_d = Some(iter); } } } - name_match + + (name_match_s, name_match_d, false) } pub fn evict_to(&mut self, reduce: u64) -> bool { - qtrace!( - [self], - "reduce table to {}, currently used:{}", - reduce, - self.used - ); while (!self.dynamic.is_empty()) && self.used > reduce { if let Some(e) = self.dynamic.back() { - if let QPackSide::Encoder = self.qpack_side { - if !e.can_reduce(self.acked_inserts_cnt) { - return false; - } + if !e.can_reduce(self.acked_inserts_cnt) { + return false; } self.used -= e.size(); self.dynamic.pop_back(); @@ -246,14 +174,14 @@ impl HeaderTable { true } - pub fn insert(&mut self, name: &[u8], value: &[u8]) -> Res { - qtrace!([self], "insert name={:?} value={:?}", name, value); + pub fn insert(&mut self, name: Vec, value: Vec) -> Res<()> { let entry = DynamicTableEntry { - name: name.to_vec(), - value: value.to_vec(), + name, + value, base: self.base, - refs: 0, + refs: HashMap::new(), }; + if entry.size() > self.capacity || !self.evict_to(self.capacity - entry.size()) { match self.qpack_side { QPackSide::Encoder => return Err(Error::EncoderStreamError), @@ -262,64 +190,51 @@ impl HeaderTable { } self.base += 1; self.used += entry.size(); - let index = entry.index(); self.dynamic.push_front(entry); - Ok(index) + Ok(()) } pub fn insert_with_name_ref( &mut self, name_static_table: bool, name_index: u64, - value: &[u8], + value: Vec, ) -> Res<()> { - qtrace!( - [self], - "insert with ref to index={} in {} value={:?}", - name_index, - if name_static_table { - "static" - } else { - "dynamic" - }, - value - ); - let name = if name_static_table { - self.get_static(name_index)?.name().to_vec() + let name; + if name_static_table { + let entry = self.get_static(name_index)?; + name = entry.name().to_vec() } else { - self.get_dynamic(name_index, self.base, false)? - .name() - .to_vec() - }; - self.insert(&name, value)?; - Ok(()) + let entry = self.get_dynamic(name_index, self.base, false)?; + name = entry.name().to_vec(); + } + self.insert(name, value) } pub fn duplicate(&mut self, index: u64) -> Res<()> { - qtrace!([self], "dumplicate entry={}", index); - // need to remember name and value because insert may delete the entry. + // need to remember name and value because inser may delete the entry. let name: Vec; let value: Vec; { let entry = self.get_dynamic(index, self.base, false)?; name = entry.name().to_vec(); value = entry.value().to_vec(); - qtrace!([self], "dumplicate name={:?} value={:?}", name, value); } - self.insert(&name, &value)?; + self.insert(name, value)?; Ok(()) } - pub fn increment_acked(&mut self, increment: u64) -> Res<()> { - qtrace!([self], "increment acked by {}", increment); + pub fn increment_acked(&mut self, increment: u64) { self.acked_inserts_cnt += increment; - if self.base < self.acked_inserts_cnt { - return Err(Error::InternalError); - } - Ok(()) } pub fn get_acked_inserts_cnt(&self) -> u64 { self.acked_inserts_cnt } + + pub fn header_ack(&mut self, stream_id: u64) { + for iter in self.dynamic.iter_mut() { + iter.remove_ref(stream_id, 1); + } + } } diff --git a/third_party/rust/neqo-transport/.cargo-checksum.json b/third_party/rust/neqo-transport/.cargo-checksum.json index 5cf9e4779358..a118f9e34221 100644 --- a/third_party/rust/neqo-transport/.cargo-checksum.json +++ b/third_party/rust/neqo-transport/.cargo-checksum.json @@ -1 +1 @@ -{"files":{"Cargo.toml":"0eda0a1e898b0294949969055376633e22bd973d5ba3525f88a2fc5ef3afc0cd","TODO":"d759cb804b32fa9d96ea8d3574a3c4073da9fe6a0b02b708a0e22cce5a5b4a0f","src/cc.rs":"9fa6bebf5fc6d9dab2b53cff38f1ea88e6cc20b1c7dc9c55a30bc75599306c34","src/cid.rs":"4161a5add9285a9f670c4d0b88ac84833b710cd99cbd5ec080f4d2f097200abf","src/connection.rs":"5300a6a55bf32fd3abdd0857455ea10a28b93e1f6ae69ef81a283915b90be49d","src/crypto.rs":"78558a8312969285d3082574e31abd349f964590a66167352da8adec7e9c6ed2","src/dump.rs":"d69ccb0e3b240823b886a791186afbac9f2e26d1f1f67a55dbf86f8cd3e6203e","src/events.rs":"07b1fa18efc538b96736ebfedba929b4854dffd460e1250ae02dc79cc86bb310","src/flow_mgr.rs":"0b1c6e7587e411635723207ecface6c62d1243599dd017745c61eb94805b9886","src/frame.rs":"2859a30e4889fd6b4124b9f88affcec5956f8f3914ccc7684525bfad085ef076","src/lib.rs":"dbaaf47b1025a5d976ceff86989e6d8729e993e525a3ef1d59046d45c0a09bf1","src/packet.rs":"a3b0b0414e8ddaddfe098fa343a89449b42fbb1ae44468d04994becebd7ab5cc","src/recovery.rs":"887a963dbc6e987caba0d74c0ce6b71212b96f87078cb1086ded560c4b930834","src/recv_stream.rs":"0a0c44a3414e6088a704c1a245fd98dd8a8ed502d80bfab90d2defe039bd37cb","src/send_stream.rs":"0c3e401bb6a7ea7babe47234d7a05c993f4a3c67f07f4130d81a8476085e746f","src/server.rs":"ca82b8bbfae29cf2fb6aadf4298d8482a734edbd460e569766d16b596aac0554","src/stats.rs":"a276bd9902939091578d3bcc95aa7dd0b32f5d29e528c12e7b93a0ab88474478","src/stream_id.rs":"b3158cf2c6072da79bf6e77a31f71f2f3b970429221142a9ff1dd6cd07df2442","src/tparams.rs":"4b328b0b146f06d805fbc77748f9cb578f9ee4a0c63565158bf36d7bd3151020","src/tracking.rs":"0df46cd244afc32ca3aef1dcd8eb9abcce364a330bd8053a2484740bb5b2b3fd","tests/conn_vectors.rs":"d42db769518162bb3e39667a999ab467541c72d7a7d42e92adbd39c16eca0811","tests/connection.rs":"a93985c199a9ef987106f4a20b35ebf803cdbbb855c07b1362b403eed7101ef8","tests/server.rs":"d516bf134a63377c98ff4ac2cca2d4cc6562f48ea262b10a23866d38c4c82af3"},"package":null} \ No newline at end of file +{"files":{"Cargo.toml":"e0f8a00f0862504bdc858abae1599f0403efb3cb69b54a3d0ea1031891790698","TODO":"d759cb804b32fa9d96ea8d3574a3c4073da9fe6a0b02b708a0e22cce5a5b4a0f","src/connection.rs":"e3af2eb9ab351538399e60f684641b6471f0d70e7c58543ed272d1f63176c4bd","src/crypto.rs":"606b705d2c91591bf91b56910f9839804ffa00d64e1b777562470c8a388ab86e","src/dump.rs":"e4058d89acf50c3dea7e9f067e6fa1d8abfe0a65a77acf187f5012b53ed2568d","src/events.rs":"07b1fa18efc538b96736ebfedba929b4854dffd460e1250ae02dc79cc86bb310","src/flow_mgr.rs":"a85aebc35358258ff5ede98cbc41a4af57d4f5528d3a168bcedbb0ff86fc9660","src/frame.rs":"13850d329895d3a44d4ba4f99ea4a0cd8f6b325361505f41ac373278e7a57f9e","src/lib.rs":"b2b8a2f67c96305870b05d7b73cff50e0f347061493480ed7a77f86b5e48149d","src/packet.rs":"9cb94fc6031d7f9590de53d6b3260b9d43fae297837a527752422453b8099436","src/recovery.rs":"66e92afd09c2aa97f606262ca65b6dfbc7a2a5f73aee9ae452d50564aea39a09","src/recv_stream.rs":"caed6677abc1bbd08ce57abf7182b6a151c31c7cb660622344ac50e96ab58653","src/send_stream.rs":"ef450a2e3e51815f50cb5016c257a02d0161a876519214cf4a71eb5bce54aa89","src/server.rs":"d391a1d585bb1e45d025cdd1adb25f986128302178926b71e17dce8105346dda","src/stats.rs":"dca5afcb6252f3f32f494513f76964cffb945afd6d18b8669dea98a7aeed1689","src/stream_id.rs":"b3158cf2c6072da79bf6e77a31f71f2f3b970429221142a9ff1dd6cd07df2442","src/tparams.rs":"325be72a070b22e03a5e9bf16c65fae4bf1835ca8b04d36b78baea3daee340f8","src/tracking.rs":"7bd00282689b7cb8d96c44105f83cd5739b46683d4aba6f58a5328d53ed18fb0","tests/conn_vectors.rs":"c594ea1c65ded6281ae1cc900cc4afa0143daa5c004adadbe29eb1ab900d8d71","tests/connection.rs":"e86725e3f59b30b9ea70e7961a5ded3ed36522f3bfcf2b7df1f502e18863121c","tests/server.rs":"938aed8ef27d5ff2dc01bf84ecb31691943c3ebe6299eab43f266de935c542da"},"package":null} \ No newline at end of file diff --git a/third_party/rust/neqo-transport/Cargo.toml b/third_party/rust/neqo-transport/Cargo.toml index ba8fde1a7c40..cf2fc1e13616 100644 --- a/third_party/rust/neqo-transport/Cargo.toml +++ b/third_party/rust/neqo-transport/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "neqo-transport" -version = "0.1.13" +version = "0.1.12" authors = ["EKR ", "Andy Grover "] edition = "2018" license = "MIT/Apache-2.0" @@ -8,7 +8,8 @@ license = "MIT/Apache-2.0" [dependencies] neqo-crypto = { path = "../neqo-crypto" } neqo-common = { path = "../neqo-common" } -lazy_static = "1.3.0" +lazy_static = "1.0" +rand = "0.7" log = "0.4.0" smallvec = "1.0.0" diff --git a/third_party/rust/neqo-transport/src/cc.rs b/third_party/rust/neqo-transport/src/cc.rs deleted file mode 100644 index 80d46f735b66..000000000000 --- a/third_party/rust/neqo-transport/src/cc.rs +++ /dev/null @@ -1,196 +0,0 @@ -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -// Congestion control - -use std::cmp::max; -use std::fmt::{self, Display}; -use std::time::{Duration, Instant}; - -use crate::tracking::SentPacket; -use neqo_common::{const_max, const_min, qdebug, qinfo, qtrace}; - -pub const MAX_DATAGRAM_SIZE: usize = 1232; // For ipv6, smaller than ipv4 (1252) -pub const INITIAL_CWND_PKTS: usize = 10; -const INITIAL_WINDOW: usize = const_min( - INITIAL_CWND_PKTS * MAX_DATAGRAM_SIZE, - const_max(2 * MAX_DATAGRAM_SIZE, 14720), -); -pub const MIN_CONG_WINDOW: usize = MAX_DATAGRAM_SIZE * 2; -const PERSISTENT_CONG_THRESH: u32 = 3; - -#[derive(Debug)] -pub struct CongestionControl { - congestion_window: usize, // = kInitialWindow - bytes_in_flight: usize, - congestion_recovery_start_time: Option, - ssthresh: usize, -} - -impl Default for CongestionControl { - fn default() -> Self { - Self { - congestion_window: INITIAL_WINDOW, - bytes_in_flight: 0, - congestion_recovery_start_time: None, - ssthresh: std::usize::MAX, - } - } -} - -impl Display for CongestionControl { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "CongCtrl {}/{} ssthresh {}", - self.bytes_in_flight, self.congestion_window, self.ssthresh - ) - } -} - -impl CongestionControl { - #[cfg(test)] - #[must_use] - pub fn cwnd(&self) -> usize { - self.congestion_window - } - - #[cfg(test)] - #[must_use] - pub fn ssthresh(&self) -> usize { - self.ssthresh - } - - #[must_use] - pub fn cwnd_avail(&self) -> usize { - // BIF can be higher than cwnd due to PTO packets, which are sent even - // if avail is 0, but still count towards BIF. - self.congestion_window.saturating_sub(self.bytes_in_flight) - } - - // Multi-packet version of OnPacketAckedCC - pub fn on_packets_acked(&mut self, acked_pkts: &[SentPacket]) { - for pkt in acked_pkts - .iter() - .filter(|pkt| pkt.in_flight) - .filter(|pkt| pkt.time_declared_lost.is_none()) - { - assert!(self.bytes_in_flight >= pkt.size); - self.bytes_in_flight -= pkt.size; - - if self.in_congestion_recovery(pkt.time_sent) { - // Do not increase congestion window in recovery period. - continue; - } - if self.app_limited() { - // Do not increase congestion_window if application limited. - continue; - } - - if self.congestion_window < self.ssthresh { - self.congestion_window += pkt.size; - qinfo!([self], "slow start"); - } else { - self.congestion_window += (MAX_DATAGRAM_SIZE * pkt.size) / self.congestion_window; - qinfo!([self], "congestion avoidance"); - } - } - } - - pub fn on_packets_lost( - &mut self, - now: Instant, - largest_acked_sent: Option, - pto: Duration, - lost_packets: &[SentPacket], - ) { - if lost_packets.is_empty() { - return; - } - - for pkt in lost_packets.iter().filter(|pkt| pkt.in_flight) { - assert!(self.bytes_in_flight >= pkt.size); - self.bytes_in_flight -= pkt.size; - } - - qdebug!([self], "Pkts lost {}", lost_packets.len()); - - let last_lost_pkt = lost_packets.last().unwrap(); - self.on_congestion_event(now, last_lost_pkt.time_sent); - - let in_persistent_congestion = { - let congestion_period = pto * PERSISTENT_CONG_THRESH; - - match largest_acked_sent { - Some(las) => las < last_lost_pkt.time_sent - congestion_period, - None => { - // Nothing has ever been acked. Could still be PC. - let first_lost_pkt_sent = lost_packets.first().unwrap().time_sent; - last_lost_pkt.time_sent - first_lost_pkt_sent > congestion_period - } - } - }; - if in_persistent_congestion { - qinfo!([self], "persistent congestion"); - self.congestion_window = MIN_CONG_WINDOW; - } - } - - pub fn discard(&mut self, pkt: &SentPacket) { - if pkt.in_flight { - assert!(self.bytes_in_flight >= pkt.size); - self.bytes_in_flight -= pkt.size; - qtrace!([self], "Ignore pkt with size {}", pkt.size); - } - } - - pub fn on_packet_sent(&mut self, pkt: &SentPacket) { - if !pkt.in_flight { - return; - } - - self.bytes_in_flight += pkt.size; - qdebug!( - [self], - "Pkt Sent len {}, bif {}, cwnd {}", - pkt.size, - self.bytes_in_flight, - self.congestion_window - ); - debug_assert!(self.bytes_in_flight <= self.congestion_window); - } - - #[must_use] - pub fn in_congestion_recovery(&self, sent_time: Instant) -> bool { - self.congestion_recovery_start_time - .map(|start| sent_time <= start) - .unwrap_or(false) - } - - fn on_congestion_event(&mut self, now: Instant, sent_time: Instant) { - // Start a new congestion event if packet was sent after the - // start of the previous congestion recovery period. - if !self.in_congestion_recovery(sent_time) { - self.congestion_recovery_start_time = Some(now); - self.congestion_window /= 2; // kLossReductionFactor = 0.5 - self.congestion_window = max(self.congestion_window, MIN_CONG_WINDOW); - self.ssthresh = self.congestion_window; - qinfo!( - [self], - "Cong event -> recovery; cwnd {}, ssthresh {}", - self.congestion_window, - self.ssthresh - ); - } else { - qdebug!([self], "Cong event but already in recovery"); - } - } - - fn app_limited(&self) -> bool { - //TODO(agrover): how do we get this info?? - false - } -} diff --git a/third_party/rust/neqo-transport/src/cid.rs b/third_party/rust/neqo-transport/src/cid.rs deleted file mode 100644 index d6ba365b6a2f..000000000000 --- a/third_party/rust/neqo-transport/src/cid.rs +++ /dev/null @@ -1,149 +0,0 @@ -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -// Encoding and decoding packets off the wire. - -use neqo_common::{hex, matches, Decoder}; -use neqo_crypto::random; - -use std::borrow::Borrow; -use std::cmp::max; - -#[derive(Clone, Default, Eq, Hash, PartialEq)] -pub struct ConnectionId { - pub(crate) cid: Vec, -} - -impl ConnectionId { - pub fn generate(len: usize) -> Self { - assert!(matches!(len, 0..=20)); - Self { cid: random(len) } - } - - // Apply a wee bit of greasing here in picking a length between 8 and 20 bytes long. - pub fn generate_initial() -> Self { - let v = random(1); - // Bias selection toward picking 8 (>50% of the time). - let len: usize = max(8, 5 + (v[0] & (v[0] >> 4))).into(); - Self::generate(len) - } - - pub fn as_ref(&self) -> ConnectionIdRef { - ConnectionIdRef::from(&self.cid[..]) - } -} - -impl Borrow<[u8]> for ConnectionId { - fn borrow(&self) -> &[u8] { - &self.cid - } -} - -impl From<&[u8]> for ConnectionId { - fn from(buf: &[u8]) -> Self { - Self { - cid: Vec::from(buf), - } - } -} - -impl<'a> From<&ConnectionIdRef<'a>> for ConnectionId { - fn from(cidref: &ConnectionIdRef<'a>) -> Self { - Self { - cid: Vec::from(cidref.cid), - } - } -} - -impl std::ops::Deref for ConnectionId { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - &self.cid - } -} - -impl ::std::fmt::Debug for ConnectionId { - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "CID {}", hex(&self.cid)) - } -} - -impl ::std::fmt::Display for ConnectionId { - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "{}", hex(&self.cid)) - } -} - -impl<'a> PartialEq> for ConnectionId { - fn eq(&self, other: &ConnectionIdRef<'a>) -> bool { - &self.cid[..] == other.cid - } -} - -#[derive(Hash, Eq, PartialEq)] -pub struct ConnectionIdRef<'a> { - cid: &'a [u8], -} - -impl<'a> ::std::fmt::Debug for ConnectionIdRef<'a> { - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "CID {}", hex(&self.cid)) - } -} - -impl<'a> ::std::fmt::Display for ConnectionIdRef<'a> { - fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "{}", hex(&self.cid)) - } -} - -impl<'a> From<&'a [u8]> for ConnectionIdRef<'a> { - fn from(cid: &'a [u8]) -> Self { - Self { cid } - } -} - -impl<'a> std::ops::Deref for ConnectionIdRef<'a> { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - &self.cid - } -} - -impl<'a> PartialEq for ConnectionIdRef<'a> { - fn eq(&self, other: &ConnectionId) -> bool { - self.cid == &other.cid[..] - } -} - -pub trait ConnectionIdDecoder { - fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option>; -} - -pub trait ConnectionIdManager: ConnectionIdDecoder { - fn generate_cid(&mut self) -> ConnectionId; - fn as_decoder(&self) -> &dyn ConnectionIdDecoder; -} - -#[cfg(test)] -mod tests { - use super::*; - use neqo_common::matches; - use test_fixture::fixture_init; - - #[test] - fn generate_initial_cid() { - fixture_init(); - for _ in 0..100 { - let cid = ConnectionId::generate_initial(); - if !matches!(cid.len(), 8..=20) { - panic!("connection ID {:?}", cid); - } - } - } -} diff --git a/third_party/rust/neqo-transport/src/connection.rs b/third_party/rust/neqo-transport/src/connection.rs index 15b65fd7268c..be7d795eebc4 100644 --- a/third_party/rust/neqo-transport/src/connection.rs +++ b/third_party/rust/neqo-transport/src/connection.rs @@ -10,6 +10,7 @@ use std::cell::RefCell; use std::cmp::{max, min, Ordering}; use std::collections::HashMap; use std::convert::TryFrom; +use std::convert::TryInto; use std::fmt::{self, Debug}; use std::net::SocketAddr; use std::rc::Rc; @@ -20,18 +21,22 @@ use smallvec::SmallVec; use neqo_common::{hex, matches, qdebug, qerror, qinfo, qtrace, qwarn, Datagram, Decoder, Encoder}; use neqo_crypto::agent::CertificateInfo; use neqo_crypto::{ - Agent, AntiReplay, AuthenticationStatus, Client, HandshakeState, Record, SecretAgentInfo, - Server, + Agent, AntiReplay, AuthenticationStatus, Client, Epoch, HandshakeState, Record, + SecretAgentInfo, Server, }; -use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdManager, ConnectionIdRef}; -use crate::crypto::{Crypto, CryptoDxState}; +use crate::crypto::{Crypto, CryptoDxDirection, CryptoDxState, CryptoState}; use crate::dump::*; use crate::events::{ConnectionEvent, ConnectionEvents}; use crate::flow_mgr::FlowMgr; -use crate::frame::{AckRange, Frame, FrameType, StreamType, TxMode}; -use crate::packet::{DecryptedPacket, PacketBuilder, PacketNumber, PacketType, PublicPacket}; -use crate::recovery::{LossRecovery, RecoveryToken}; +use crate::frame::{decode_frame, AckRange, Frame, FrameType, StreamType, TxMode}; +use crate::packet::{ + decode_packet_hdr, decrypt_packet, encode_packet, ConnectionId, ConnectionIdDecoder, PacketHdr, + PacketNumberDecoder, PacketType, +}; +use crate::recovery::{ + LossRecovery, LossRecoveryMode, LossRecoveryState, RecoveryToken, SentPacket, +}; use crate::recv_stream::{RecvStream, RecvStreams, RX_STREAM_DATA_WINDOW}; use crate::send_stream::{SendStream, SendStreams}; use crate::stats::Stats; @@ -39,17 +44,22 @@ use crate::stream_id::{StreamId, StreamIndex, StreamIndexes}; use crate::tparams::{ tp_constants, TransportParameter, TransportParameters, TransportParametersHandler, }; -use crate::tracking::{AckTracker, PNSpace, SentPacket}; -use crate::{AppError, ConnectionError, Error, Res, LOCAL_IDLE_TIMEOUT}; +use crate::tracking::{AckTracker, PNSpace}; +use crate::QUIC_VERSION; +use crate::{AppError, ConnectionError, Error, Res}; #[derive(Debug, Default)] struct Packet(Vec); +const NUM_EPOCHS: Epoch = 4; + pub const LOCAL_STREAM_LIMIT_BIDI: u64 = 16; pub const LOCAL_STREAM_LIMIT_UNI: u64 = 16; const LOCAL_MAX_DATA: u64 = 0x3FFF_FFFF_FFFF_FFFF; // 2^62-1 +const LOCAL_IDLE_TIMEOUT: Duration = Duration::from_secs(60); // 1 minute + #[derive(Debug, PartialEq, Copy, Clone)] /// Client or Server. pub enum Role { @@ -60,8 +70,8 @@ pub enum Role { impl Role { pub fn remote(self) -> Self { match self { - Self::Client => Self::Server, - Self::Server => Self::Client, + Role::Client => Role::Server, + Role::Server => Role::Client, } } } @@ -79,7 +89,6 @@ pub enum State { WaitInitial, Handshaking, Connected, - Confirmed, Closing { error: ConnectionError, frame_type: FrameType, @@ -89,13 +98,6 @@ pub enum State { Closed(ConnectionError), } -impl State { - #[must_use] - pub fn connected(&self) -> bool { - matches!(self, Self::Connected | Self::Confirmed) - } -} - // Implement Ord so that we can enforce monotonic state progression. impl PartialOrd for State { #[allow(clippy::match_same_arms)] // Lint bug: rust-lang/rust-clippy#860 @@ -104,19 +106,17 @@ impl PartialOrd for State { return Some(Ordering::Equal); } Some(match (self, other) { - (Self::Init, _) => Ordering::Less, - (_, Self::Init) => Ordering::Greater, - (Self::WaitInitial, _) => Ordering::Less, - (_, Self::WaitInitial) => Ordering::Greater, - (Self::Handshaking, _) => Ordering::Less, - (_, Self::Handshaking) => Ordering::Greater, - (Self::Connected, _) => Ordering::Less, - (_, Self::Connected) => Ordering::Greater, - (Self::Confirmed, _) => Ordering::Less, - (_, Self::Confirmed) => Ordering::Greater, - (Self::Closing { .. }, _) => Ordering::Less, - (_, Self::Closing { .. }) => Ordering::Greater, - (Self::Closed(_), _) => unreachable!(), + (State::Init, _) => Ordering::Less, + (_, State::Init) => Ordering::Greater, + (State::WaitInitial, _) => Ordering::Less, + (_, State::WaitInitial) => Ordering::Greater, + (State::Handshaking, _) => Ordering::Less, + (_, State::Handshaking) => Ordering::Greater, + (State::Connected, _) => Ordering::Less, + (_, State::Connected) => Ordering::Greater, + (State::Closing { .. }, _) => Ordering::Less, + (_, State::Closing { .. }) => Ordering::Greater, + (State::Closed(_), _) => unreachable!(), }) } } @@ -124,9 +124,9 @@ impl PartialOrd for State { #[derive(Debug)] enum ZeroRttState { Init, - Sending, + Sending(CryptoDxState), AcceptedClient, - AcceptedServer, + AcceptedServer(CryptoDxState), Rejected, } @@ -177,10 +177,9 @@ pub enum Output { impl Output { /// Convert into an `Option`. - #[must_use] pub fn dgram(self) -> Option { match self { - Self::Datagram(dg) => Some(dg), + Output::Datagram(dg) => Some(dg), _ => None, } } @@ -188,21 +187,16 @@ impl Output { /// Get a reference to the Datagram, if any. pub fn as_dgram_ref(&self) -> Option<&Datagram> { match self { - Self::Datagram(dg) => Some(dg), + Output::Datagram(dg) => Some(dg), _ => None, } } - - /// Ask how long the caller should wait before calling back. - #[must_use] - pub fn callback(&self) -> Duration { - match self { - Self::Callback(t) => *t, - _ => Duration::new(0, 0), - } - } } +pub trait ConnectionIdManager: ConnectionIdDecoder { + fn generate_cid(&mut self) -> ConnectionId; + fn as_decoder(&self) -> &dyn ConnectionIdDecoder; +} /// Alias the common form for ConnectionIdManager. type CidMgr = Rc>; @@ -216,8 +210,8 @@ impl FixedConnectionIdManager { } } impl ConnectionIdDecoder for FixedConnectionIdManager { - fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option> { - dec.decode(self.len).map(ConnectionIdRef::from) + fn decode_cid(&self, dec: &mut Decoder) -> Option { + dec.decode(self.len).map(ConnectionId::from) } } impl ConnectionIdManager for FixedConnectionIdManager { @@ -234,15 +228,6 @@ struct RetryInfo { odcid: ConnectionId, } -impl RetryInfo { - fn new(odcid: ConnectionId) -> Self { - Self { - token: Vec::new(), - odcid, - } - } -} - #[derive(Debug, Clone)] /// There's a little bit of different behavior for resetting idle timeout. See /// -transport 10.2 ("Idle Timeout"). @@ -254,15 +239,15 @@ enum IdleTimeout { impl Default for IdleTimeout { fn default() -> Self { - Self::Init + IdleTimeout::Init } } impl IdleTimeout { pub fn as_instant(&self) -> Option { match self { - Self::Init => None, - Self::PacketReceived(t) | Self::AckElicitingPacketSent(t) => Some(*t), + IdleTimeout::Init => None, + IdleTimeout::PacketReceived(t) | IdleTimeout::AckElicitingPacketSent(t) => Some(*t), } } @@ -270,15 +255,15 @@ impl IdleTimeout { // Only reset idle timeout if we've received a packet since the last // time we reset the timeout here. match self { - Self::AckElicitingPacketSent(_) => {} - Self::Init | Self::PacketReceived(_) => { - *self = Self::AckElicitingPacketSent(now + LOCAL_IDLE_TIMEOUT); + IdleTimeout::AckElicitingPacketSent(_) => {} + IdleTimeout::Init | IdleTimeout::PacketReceived(_) => { + *self = IdleTimeout::AckElicitingPacketSent(now + LOCAL_IDLE_TIMEOUT); } } } fn on_packet_received(&mut self, now: Instant) { - *self = Self::PacketReceived(now + LOCAL_IDLE_TIMEOUT); + *self = IdleTimeout::PacketReceived(now + LOCAL_IDLE_TIMEOUT); } pub fn expired(&self, now: Instant) -> bool { @@ -290,53 +275,6 @@ impl IdleTimeout { } } -/// StateManagement manages whether we need to send HANDSHAKE_DONE and CONNECTION_CLOSE. -/// Valid state transitions are: -/// * Idle -> HandshakeDone: at the server when the handshake completes -/// * HandshakeDone -> Idle: when a HANDSHAKE_DONE frame is sent -/// * Idle/HandshakeDone -> ConnectionClose: when closing -/// * ConnectionClose -> CloseSent: after sending CONNECTION_CLOSE -/// * CloseSent -> ConnectionClose: any time a new CONNECTION_CLOSE is needed -#[derive(Debug, Clone, PartialEq)] -enum StateSignaling { - Idle, - HandshakeDone, - ConnectionClose, - CloseSent, -} - -impl StateSignaling { - pub fn handshake_done(&mut self) { - if *self != Self::Idle { - debug_assert!(false, "StateSignaling must be in Idle state."); - return; - } - *self = Self::HandshakeDone - } - - pub fn send_done(&mut self) -> Option<(Frame, Option)> { - if *self == Self::HandshakeDone { - *self = Self::Idle; - Some((Frame::HandshakeDone, Some(RecoveryToken::HandshakeDone))) - } else { - None - } - } - - pub fn closing(&self) -> bool { - *self == Self::ConnectionClose - } - - pub fn close(&mut self) { - *self = Self::ConnectionClose - } - - pub fn close_sent(&mut self) { - debug_assert!(self.closing()); - *self = Self::CloseSent - } -} - /// A QUIC Connection /// /// First, create a new connection using `new_client()` or `new_server()`. @@ -353,6 +291,7 @@ impl StateSignaling { /// After the connection is closed (either by calling `close()` or by the /// remote) continue processing until `state()` returns `Closed`. pub struct Connection { + version: crate::packet::Version, role: Role, state: State, tps: Rc>, @@ -375,20 +314,20 @@ pub struct Connection { pub(crate) send_streams: SendStreams, pub(crate) recv_streams: RecvStreams, pub(crate) flow_mgr: Rc>, - state_signaling: StateSignaling, loss_recovery: LossRecovery, + loss_recovery_state: LossRecoveryState, events: ConnectionEvents, token: Option>, stats: Stats, + tx_mode: TxMode, } impl Debug for Connection { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, + f.write_fmt(format_args!( "{:?} Connection: {:?} {:?}", self.role, self.state, self.path - ) + )) } } @@ -416,8 +355,7 @@ impl Connection { remote_cid: dcid.clone(), }), ); - c.crypto.states.init(Role::Client, &dcid); - c.retry_info = Some(RetryInfo::new(dcid)); + c.crypto.create_initial_state(Role::Client, &dcid); Ok(c) } @@ -462,13 +400,13 @@ impl Connection { tps.set_integer(tp_constants::INITIAL_MAX_DATA, LOCAL_MAX_DATA); tps.set_integer( tp_constants::IDLE_TIMEOUT, - u64::try_from(LOCAL_IDLE_TIMEOUT.as_millis()).unwrap(), + LOCAL_IDLE_TIMEOUT.as_millis().try_into().unwrap(), ); tps.set_empty(tp_constants::DISABLE_MIGRATION); } fn new( - role: Role, + r: Role, agent: Agent, cid_manager: CidMgr, anti_replay: Option<&AntiReplay>, @@ -481,8 +419,9 @@ impl Connection { .expect("TLS should be configured successfully"); Self { - role, - state: match role { + version: QUIC_VERSION, + role: r, + state: match r { Role::Client => State::Init, Role::Server => State::WaitInitial, }, @@ -500,11 +439,12 @@ impl Connection { send_streams: SendStreams::default(), recv_streams: RecvStreams::default(), flow_mgr: Rc::new(RefCell::new(FlowMgr::default())), - state_signaling: StateSignaling::Idle, loss_recovery: LossRecovery::new(), + loss_recovery_state: LossRecoveryState::default(), events: ConnectionEvents::default(), token: None, stats: Stats::default(), + tx_mode: TxMode::Normal, } } @@ -541,7 +481,7 @@ impl Connection { /// Access the latest resumption token on the connection. pub fn resumption_token(&self) -> Option> { - if !self.state.connected() { + if self.state != State::Connected { return None; } match self.crypto.tls { @@ -631,7 +571,7 @@ impl Connection { /// Call by application when the peer cert has been verified pub fn authenticated(&mut self, status: AuthenticationStatus, now: Instant) { self.crypto.tls.authenticated(status); - let res = self.handshake(now, PNSpace::Handshake, None); + let res = self.handshake(now, 0, None); self.absorb_error(now, res); } @@ -684,16 +624,13 @@ impl Connection { return; } - let res = self.crypto.states.check_key_update(now); - self.absorb_error(now, res); - if self.idle_timeout.expired(now) { qinfo!("idle timeout expired"); self.set_state(State::Closed(ConnectionError::Transport( Error::IdleTimeout, ))); - } else if let Some(packets) = self.loss_recovery.check_loss_detection_timeout(now) { - self.handle_lost_packets(&packets); + } else { + self.check_loss_detection_timeout(now); } } @@ -707,7 +644,7 @@ impl Connection { /// Just like above but returns frames parsed from the datagram #[cfg(test)] - pub fn test_process_input(&mut self, dgram: Datagram, now: Instant) -> Vec<(Frame, PNSpace)> { + pub fn test_process_input(&mut self, dgram: Datagram, now: Instant) -> Vec<(Frame, Epoch)> { let res = self.input(dgram, now); let frames = self.absorb_error(now, res).unwrap_or_default(); self.cleanup_streams(); @@ -716,40 +653,28 @@ impl Connection { /// Get the time that we next need to be called back, relative to `now`. fn next_delay(&mut self, now: Instant) -> Duration { - qtrace!([self], "Get callback delay {:?}", now); + self.loss_recovery_state = self.loss_recovery.get_timer(); + let mut delays = SmallVec::<[_; 4]>::new(); - if let Some(lr_time) = self.loss_recovery.calculate_timer() { - qtrace!([self], "Loss recovery timer {:?}", lr_time); + if let Some(lr_time) = self.loss_recovery_state.callback_time() { delays.push(lr_time); } if let Some(ack_time) = self.acks.ack_time() { - qtrace!([self], "Delayed ACK timer {:?}", ack_time); delays.push(ack_time); } if let Some(idle_time) = self.idle_timeout.as_instant() { - qtrace!([self], "Idle timer {:?}", idle_time); delays.push(idle_time); } - if let Some(key_update_time) = self.crypto.states.update_time() { - qtrace!([self], "Key update timer {:?}", key_update_time); - delays.push(key_update_time); - } - // Should always at least have idle timeout, once connected assert!(!delays.is_empty()); let earliest = delays.into_iter().min().unwrap(); // TODO(agrover, mt) - need to analyze and fix #47 // rather than just clamping to zero here. - qdebug!( - [self], - "delay duration {:?}", - max(now, earliest).duration_since(now) - ); max(now, earliest).duration_since(now) } @@ -797,173 +722,221 @@ impl Connection { self.process_output(now) } - fn is_valid_cid(&self, cid: &ConnectionIdRef) -> bool { - let check = |c| c == cid; - self.valid_cids.iter().any(check) - || self.path.iter().any(|p| p.local_cids.iter().any(check)) + fn is_valid_cid(&self, cid: &ConnectionId) -> bool { + self.valid_cids.contains(cid) || self.path.iter().any(|p| p.local_cids.contains(cid)) } - fn handle_retry(&mut self, packet: PublicPacket) -> Res<()> { - qdebug!([self], "received Retry"); - debug_assert!(self.retry_info.is_some()); - if !self.retry_info.as_ref().unwrap().token.is_empty() { - qinfo!([self], "Dropping extra Retry"); - self.stats.dropped_rx += 1; - return Ok(()); - } - if packet.token().is_empty() { - qinfo!([self], "Dropping Retry without a token"); - self.stats.dropped_rx += 1; - return Ok(()); - } - if !packet.is_valid_retry(&self.retry_info.as_ref().unwrap().odcid) { - qinfo!([self], "Dropping Retry with bad integrity tag"); - self.stats.dropped_rx += 1; - return Ok(()); - } - if let Some(p) = &mut self.path { - // At this point, we shouldn't have a remote connection ID for the path. - p.remote_cid = ConnectionId::from(packet.scid()); + fn is_valid_initial(&self, hdr: &PacketHdr) -> bool { + if let PacketType::Initial(token) = &hdr.tipe { + // Server checks the token, so if we have one, + // assume that the DCID is OK. + if hdr.dcid.len() < 8 { + if token.is_empty() { + qinfo!([self], "Drop Initial with short DCID"); + false + } else { + qinfo!([self], "Initial received with token, assuming OK"); + true + } + } else { + // This is the normal path. Don't log. + true + } } else { - qinfo!([self], "No path, but we received a Retry"); - return Err(Error::InternalError); - }; - self.retry_info.as_mut().unwrap().token = packet.token().to_vec(); + qdebug!([self], "Dropping non-Initial packet"); + false + } + } + + fn handle_retry(&mut self, scid: &ConnectionId, odcid: &ConnectionId, token: &[u8]) -> Res<()> { + qdebug!([self], "received Retry"); + if self.retry_info.is_some() { + qinfo!([self], "Dropping extra Retry"); + return Ok(()); + } + if token.is_empty() { + qinfo!([self], "Dropping Retry without a token"); + return Ok(()); + } + match self.path.iter_mut().find(|p| p.remote_cid == *odcid) { + None => { + qinfo!([self], "Ignoring Retry with mismatched ODCID"); + return Ok(()); + } + Some(path) => { + path.remote_cid = scid.clone(); + } + } qinfo!( [self], - "Valid Retry received, token={}", - hex(packet.token()) + "Valid Retry received, restarting with provided token" ); + self.retry_info = Some(RetryInfo { + token: token.to_vec(), + odcid: odcid.clone(), + }); let lost_packets = self.loss_recovery.retry(); self.handle_lost_packets(&lost_packets); // Switching crypto state here might not happen eventually. // https://github.com/quicwg/base-drafts/issues/2823 - self.crypto.states.init(self.role, packet.scid()); + self.crypto.create_initial_state(self.role, scid); Ok(()) } - fn discard_keys(&mut self, space: PNSpace) { - self.loss_recovery.discard(space); - self.crypto.discard(space); - } - - fn input(&mut self, d: Datagram, now: Instant) -> Res> { + fn input(&mut self, d: Datagram, now: Instant) -> Res> { let mut slc = &d[..]; let mut frames = Vec::new(); - qtrace!([self], "input {}", hex(&**d)); + qdebug!([self], "input {}", hex(&**d)); // Handle each packet in the datagram while !slc.is_empty() { - let (packet, remainder) = - match PublicPacket::decode(slc, self.cid_manager.borrow().as_decoder()) { - Ok((packet, remainder)) => (packet, remainder), - Err(e) => { - qinfo!([self], "Garbage packet: {} {}", e, hex(slc)); - self.stats.dropped_rx += 1; - return Ok(frames); - } - }; // TODO(mt) use in place of res, and allow errors + let res = decode_packet_hdr(self.cid_manager.borrow().as_decoder(), slc); + let mut hdr = match res { + Ok(h) => h, + Err(e) => { + qinfo!( + [self], + "Received indecipherable packet header {} {}", + hex(slc), + e + ); + return Ok(frames); // Drop the remainder of the datagram. + } + }; self.stats.packets_rx += 1; - match (packet.packet_type(), &self.state, &self.role) { - (PacketType::VersionNegotiation, State::WaitInitial, Role::Client) => { + match (&hdr.tipe, &self.state, &self.role) { + (PacketType::VN(_), State::WaitInitial, Role::Client) => { self.set_state(State::Closed(ConnectionError::Transport( Error::VersionNegotiation, ))); return Err(Error::VersionNegotiation); } - (PacketType::Retry, State::WaitInitial, Role::Client) => { - self.handle_retry(packet)?; + (PacketType::Retry { odcid, token }, State::WaitInitial, Role::Client) => { + self.handle_retry(hdr.scid.as_ref().unwrap(), odcid, token)?; return Ok(frames); } - (PacketType::VersionNegotiation, ..) | (PacketType::Retry, ..) => { - qwarn!("dropping {:?}", packet.packet_type()); - self.stats.dropped_rx += 1; + (PacketType::VN(_), ..) | (PacketType::Retry { .. }, ..) => { + qwarn!("dropping {:?}", hdr.tipe); return Ok(frames); } _ => {} }; + if let Some(version) = hdr.version { + if version != self.version { + qwarn!( + "Dropping packet from version {:x} (self.version={:x})", + hdr.version.unwrap(), + self.version, + ); + return Ok(frames); + } + } + match self.state { State::Init => { qinfo!([self], "Received message while in Init state"); - self.stats.dropped_rx += 1; return Ok(frames); } State::WaitInitial => { qinfo!([self], "Received packet in WaitInitial"); if self.role == Role::Server { - if !packet.is_valid_initial() { - self.stats.dropped_rx += 1; + if !self.is_valid_initial(&hdr) { return Ok(frames); } - self.crypto.states.init(self.role, &packet.dcid()); + self.crypto.create_initial_state(self.role, &hdr.dcid); } } - State::Handshaking | State::Connected | State::Confirmed => { - if !self.is_valid_cid(packet.dcid()) { - qinfo!([self], "Ignoring packet with CID {:?}", packet.dcid()); - self.stats.dropped_rx += 1; + State::Handshaking | State::Connected => { + if !self.is_valid_cid(&hdr.dcid) { + qinfo!([self], "Ignoring packet with CID {:?}", hdr.dcid); return Ok(frames); } - if self.role == Role::Server && packet.packet_type() == PacketType::Handshake { - // Server has received a Handshake packet -> discard Initial keys and states - self.discard_keys(PNSpace::Initial); - } } State::Closing { .. } => { // Don't bother processing the packet. Instead ask to get a // new close frame. - self.state_signaling.close(); + self.flow_mgr.borrow_mut().set_need_close_frame(true); return Ok(frames); } State::Closed(..) => { // Do nothing. - self.stats.dropped_rx += 1; return Ok(frames); } } - qtrace!([self], "Received unverified packet {:?}", packet); + qdebug!([self], "Received unverified packet {:?}", hdr); - let pto = self.loss_recovery.pto(); - let payload = packet.decrypt(&mut self.crypto.states, now + pto); - slc = remainder; - if let Ok(payload) = payload { + let body = self.decrypt_body(&mut hdr, slc); + slc = &slc[hdr.hdr_len + hdr.body_len()..]; + if let Some(body) = body { // TODO(ekr@rtfm.com): Have the server blow away the initial // crypto state if this fails? Otherwise, we will get a panic // on the assert for doesn't exist. // OK, we have a valid packet. self.idle_timeout.on_packet_received(now); - dump_packet( - self, - "-> RX", - payload.packet_type(), - payload.pn(), - &payload[..], - ); - frames.extend(self.process_packet(&payload, now)?); + dump_packet(self, "-> RX", &hdr, &body); + frames.extend(self.process_packet(&hdr, body, now)?); if matches!(self.state, State::WaitInitial) { - self.start_handshake(&packet, &d)?; + self.start_handshake(hdr, &d)?; } self.process_migrations(&d)?; - } else { - // Decryption failure, or not having keys is not fatal. - // If the state isn't available, or we can't decrypt the packet, drop - // the rest of the datagram on the floor, but don't generate an error. - self.stats.dropped_rx += 1; } } Ok(frames) } + fn obtain_epoch_rx_crypto_state(&mut self, epoch: Epoch) -> Option<&mut CryptoDxState> { + if (self.state == State::Handshaking) && (epoch == 3) && (self.role() == Role::Server) { + // We got a packet for epoch 3 but the connection is still in the Handshaking + // state -> discharge the packet. + // On the server side we have keys for epoch 3 before we enter the epoch, + // but we still need to discharge the packet. + None + } else if epoch != 1 { + match self + .crypto + .states + .obtain(self.role, epoch, &self.crypto.tls) + { + Ok(CryptoState { rx, .. }) => rx.as_mut(), + _ => None, + } + } else if self.role == Role::Server { + if let ZeroRttState::AcceptedServer(rx) = &mut self.zero_rtt_state { + return Some(rx); + } + None + } else { + None + } + } + + fn decrypt_body(&mut self, mut hdr: &mut PacketHdr, slc: &[u8]) -> Option> { + // Decryption failure, or not having keys is not fatal. + // If the state isn't available, or we can't decrypt the packet, drop + // the rest of the datagram on the floor, but don't generate an error. + let largest_acknowledged = self + .loss_recovery + .largest_acknowledged_pn(PNSpace::from(hdr.epoch)); + match self.obtain_epoch_rx_crypto_state(hdr.epoch) { + Some(rx) => { + let pn_decoder = PacketNumberDecoder::new(largest_acknowledged); + decrypt_packet(rx, pn_decoder, &mut hdr, slc).ok() + } + _ => None, + } + } + /// Ok(true) if the packet is a duplicate fn process_packet( &mut self, - packet: &DecryptedPacket, + hdr: &PacketHdr, + body: Vec, now: Instant, - ) -> Res> { + ) -> Res> { // TODO(ekr@rtfm.com): Have the server blow away the initial // crypto state if this fails? Otherwise, we will get a panic // on the assert for doesn't exist. @@ -971,71 +944,95 @@ impl Connection { // TODO(ekr@rtfm.com): Filter for valid for this epoch. - let space = PNSpace::from(packet.packet_type()); - if self.acks[space].is_duplicate(packet.pn()) { - qdebug!([self], "Duplicate packet from {} pn={}", space, packet.pn()); + let space = PNSpace::from(hdr.epoch); + if self.acks[space].is_duplicate(hdr.pn) { + qdebug!( + [self], + "Received duplicate packet epoch={} pn={}", + hdr.epoch, + hdr.pn + ); self.stats.dups_rx += 1; return Ok(vec![]); } let mut ack_eliciting = false; - let mut d = Decoder::from(&packet[..]); - let mut consecutive_padding = 0; + let mut d = Decoder::from(&body[..]); #[allow(unused_mut)] let mut frames = Vec::new(); while d.remaining() > 0 { - let mut f = Frame::decode(&mut d)?; - - // Skip padding - while f == Frame::Padding && d.remaining() > 0 { - consecutive_padding += 1; - f = Frame::decode(&mut d)?; - } - if consecutive_padding > 0 { - qdebug!("PADDING frame repeated {} times", consecutive_padding); - consecutive_padding = 0; - } - + let f = decode_frame(&mut d)?; if cfg!(test) { - frames.push((f.clone(), space)); + frames.push((f.clone(), hdr.epoch)); } ack_eliciting |= f.ack_eliciting(); let t = f.get_type(); - let res = self.input_frame(packet.packet_type(), f, now); + let res = self.input_frame(hdr.epoch, f, now); self.capture_error(now, t, res)?; } - self.acks[space].set_received(now, packet.pn(), ack_eliciting); + self.acks[space].set_received(now, hdr.pn, ack_eliciting); Ok(frames) } - fn start_handshake(&mut self, packet: &PublicPacket, d: &Datagram) -> Res<()> { + fn get_zero_rtt_crypto(&mut self) -> Option { + match self.crypto.tls.preinfo() { + Err(_) => None, + Ok(preinfo) => { + match preinfo.early_data_cipher() { + Some(cipher) => { + match self.role { + Role::Client => self.crypto.tls.write_secret(1).map(|ws| { + CryptoDxState::new(CryptoDxDirection::Write, 1, ws, cipher) + }), + Role::Server => self.crypto.tls.read_secret(1).map(|rs| { + CryptoDxState::new(CryptoDxDirection::Read, 1, rs, cipher) + }), + } + } + None => None, + } + } + } + } + + fn start_handshake(&mut self, hdr: PacketHdr, d: &Datagram) -> Res<()> { if self.role == Role::Server { - assert_eq!(packet.packet_type(), PacketType::Initial); + assert!(matches!(hdr.tipe, PacketType::Initial(..))); // A server needs to accept the client's selected CID during the handshake. - self.valid_cids.push(ConnectionId::from(packet.dcid())); + self.valid_cids.push(hdr.dcid.clone()); // Install a path. assert!(self.path.is_none()); - let mut p = Path::new(&d, ConnectionId::from(packet.scid())); + let mut p = Path::new(&d, hdr.scid.unwrap()); p.local_cids .push(self.cid_manager.borrow_mut().generate_cid()); self.path = Some(p); - self.zero_rtt_state = match self.crypto.enable_0rtt(self.role) { - Ok(true) => { - qdebug!([self], "Accepted 0-RTT"); - ZeroRttState::AcceptedServer + // SecretAgentPreinfo::early_data() always returns false for a server, + // but a non-zero maximum tells us if we are accepting 0-RTT. + self.zero_rtt_state = if self.crypto.tls.preinfo()?.max_early_data() > 0 { + match self.get_zero_rtt_crypto() { + Some(cs) => ZeroRttState::AcceptedServer(cs), + None => { + debug_assert!(false, "We must have zero-rtt keys."); + ZeroRttState::Rejected + } } - _ => ZeroRttState::Rejected, + } else { + ZeroRttState::Rejected }; } else { - qdebug!([self], "Changing to use Server CID={}", packet.scid()); + qdebug!( + [self], + "Changing to use Server CID={}", + hdr.scid.as_ref().unwrap() + ); let p = self .path .iter_mut() .find(|p| p.received_on(&d)) .expect("should have a path for sending Initial"); - p.remote_cid = ConnectionId::from(packet.scid()); + p.remote_cid = hdr.scid.unwrap(); } self.set_state(State::Handshaking); Ok(()) @@ -1052,338 +1049,261 @@ impl Connection { } fn output(&mut self, now: Instant) -> Option { - if let Some(mut path) = self.path.take() { - let res = match &self.state { - State::Init - | State::WaitInitial - | State::Handshaking - | State::Connected - | State::Confirmed => self.output_path(&mut path, now), + let mut out = None; + if self.path.is_some() { + match self.output_pkt_for_path(now) { + Ok(res) => { + out = res; + } + Err(e) => { + if !matches!(self.state, State::Closing{..}) { + // An error here causes us to transition to closing. + let err: Result, Error> = Err(e); + self.absorb_error(now, err); + // Rerun to give a chance to send a CONNECTION_CLOSE. + out = match self.output_pkt_for_path(now) { + Ok(x) => x, + Err(e) => { + qwarn!([self], "two output_path errors in a row: {:?}", e); + None + } + }; + } + } + }; + } + out + } + + #[allow(clippy::cognitive_complexity)] + #[allow(clippy::useless_let_if_seq)] + /// Build a datagram, possibly from multiple packets (for different PN + /// spaces) and each containing 1+ frames. + fn output_pkt_for_path(&mut self, now: Instant) -> Res> { + let mut out_bytes = Vec::new(); + let mut needs_padding = false; + let mut close_sent = false; + let path = self + .path + .take() + .expect("we know we have a path because calling fn checked"); + + // Frames for different epochs must go in different packets, but then these + // packets can go in a single datagram + for epoch in 0..NUM_EPOCHS { + let space = PNSpace::from(epoch); + let mut encoder = Encoder::default(); + let mut tokens = Vec::new(); + + // Ensure we have tx crypto state for this epoch, or skip it. + let tx = if epoch == 1 && self.role == Role::Server { + continue; + } else if epoch == 1 { + match &mut self.zero_rtt_state { + ZeroRttState::Sending(tx) => tx, + _ => continue, + } + } else { + match self + .crypto + .states + .obtain(self.role, epoch, &self.crypto.tls) + { + Ok(CryptoState { tx: Some(tx), .. }) => tx, + _ => continue, + } + }; + + let hdr = PacketHdr::new( + 0, + match epoch { + 0 => { + let token = match &self.retry_info { + Some(v) => v.token.clone(), + _ => Vec::new(), + }; + PacketType::Initial(token) + } + 1 => PacketType::ZeroRTT, + 2 => PacketType::Handshake, + 3 => PacketType::Short, + _ => unimplemented!(), // TODO(ekr@rtfm.com): Key Update. + }, + Some(self.version), + path.remote_cid.clone(), + path.local_cids.first().cloned(), + self.loss_recovery.next_pn(space), + epoch, + ); + + let mut ack_eliciting = false; + let mut has_padding = false; + let cong_avail = match self.tx_mode { + TxMode::Normal => usize::try_from(self.loss_recovery.cwnd_avail()).unwrap(), + TxMode::Pto => path.mtu(), // send one packet + }; + let tx_mode = self.tx_mode; + + match &self.state { + State::Init | State::WaitInitial | State::Handshaking | State::Connected => { + loop { + let used = + out_bytes.len() + encoder.len() + hdr.overhead(&tx.aead, path.mtu()); + let remaining = min( + path.mtu().saturating_sub(used), + cong_avail.saturating_sub(used), + ); + if remaining < 2 { + // All useful frames are at least 2 bytes. + break; + } + + // Try to get a frame from frame sources + let mut frame = None; + if self.tx_mode == TxMode::Normal { + frame = self.acks.get_frame(now, epoch); + } + if frame.is_none() { + frame = self.crypto.streams.get_frame(epoch, tx_mode, remaining) + } + if frame.is_none() && self.tx_mode == TxMode::Normal { + frame = self.flow_mgr.borrow_mut().get_frame(epoch, remaining); + } + if frame.is_none() { + frame = self.send_streams.get_frame(epoch, tx_mode, remaining) + } + if frame.is_none() && self.tx_mode == TxMode::Pto { + frame = Some((Frame::Ping, None)); + } + + if let Some((frame, token)) = frame { + ack_eliciting |= frame.ack_eliciting(); + if let Frame::Padding = frame { + has_padding |= true; + } + frame.marshal(&mut encoder); + if let Some(t) = token { + tokens.push(t); + } + + // Pto only ever sends one frame, but it ALWAYS + // sends one + if self.tx_mode == TxMode::Pto { + break; + } + } else { + // No more frames to send. + assert_eq!(self.tx_mode, TxMode::Normal); + break; + } + } + } State::Closing { error, frame_type, msg, .. } => { - let err = error.clone(); - let frame_type = *frame_type; - let msg = msg.clone(); - self.output_close(&path, err, frame_type, msg) - } - State::Closed(_) => Ok(None), - }; - let out = self.absorb_error(now, res).unwrap_or(None); - self.path = Some(path); - out - } else { - None - } - } - - fn build_packet_header( - path: &Path, - space: PNSpace, - encoder: Encoder, - tx: &CryptoDxState, - retry_info: &Option, - ) -> (PacketType, PacketNumber, PacketBuilder) { - let pt = match space { - PNSpace::Initial => PacketType::Initial, - PNSpace::Handshake => PacketType::Handshake, - PNSpace::ApplicationData => { - if tx.is_0rtt() { - PacketType::ZeroRtt - } else { - PacketType::Short + if self.flow_mgr.borrow().need_close_frame() { + // ConnectionClose frame not allowed for 0RTT + if epoch == 1 { + continue; + } + // ConnectionError::Application only allowed at 1RTT + if epoch != 3 && matches!(error, ConnectionError::Application(_)) { + continue; + } + let frame = Frame::ConnectionClose { + error_code: error.clone().into(), + frame_type: *frame_type, + reason_phrase: Vec::from(msg.clone()), + }; + frame.marshal(&mut encoder); + close_sent = true; + } } + State::Closed { .. } => unimplemented!(), } - }; - let mut builder = if pt == PacketType::Short { - qdebug!("Building Short dcid {}", &path.remote_cid,); - PacketBuilder::short(encoder, tx.key_phase(), &path.remote_cid) - } else { - qdebug!( - "Building {:?} dcid {} scid {}", - pt, - &path.remote_cid, - path.local_cids.first().unwrap() - ); - PacketBuilder::long( - encoder, - pt, - &path.remote_cid, - path.local_cids.first().unwrap(), - ) - }; - if pt == PacketType::Initial { - builder.initial_token(if let Some(info) = retry_info { - qtrace!("Initial token {}", hex(&info.token)); - &info.token - } else { - &[] - }); - } - // TODO(mt) work out packet number length based on `4*path CWND/path MTU`. - let pn = tx.next_pn(); - builder.pn(pn, 3); - (pt, pn, builder) - } - - fn output_close( - &mut self, - path: &Path, - error: ConnectionError, - frame_type: FrameType, - msg: String, - ) -> Res> { - if !self.state_signaling.closing() { - return Ok(None); - } - let mut close_sent = false; - let mut encoder = Encoder::with_capacity(path.mtu()); - for space in PNSpace::iter() { - let tx = if let Some(tx_state) = self.crypto.states.tx(*space) { - tx_state - } else { - continue; - }; - - // ConnectionClose frame not allowed for 0RTT. - if tx.is_0rtt() { + assert!(encoder.len() <= path.mtu()); + if encoder.len() == 0 { continue; } - // ConnectionError::Application only allowed at 1RTT. - if *space != PNSpace::ApplicationData - && matches!(error, ConnectionError::Application(_)) - { - continue; - } - let (_, _, mut builder) = Self::build_packet_header(path, *space, encoder, tx, &None); - let frame = Frame::ConnectionClose { - error_code: error.clone().into(), - frame_type, - reason_phrase: Vec::from(msg.clone()), - }; - frame.marshal(&mut builder); - encoder = builder.build(tx)?; - close_sent = true; - } - - if close_sent { - self.state_signaling.close_sent(); - } - Ok(Some(Datagram::new(path.local, path.remote, encoder))) - } - - /// Add frames to the provided builder and - /// return whether any of them were ACK eliciting. - #[allow(clippy::useless_let_if_seq)] - fn add_frames( - &mut self, - builder: &mut PacketBuilder, - space: PNSpace, - tx_mode: TxMode, - limit: usize, - now: Instant, - ) -> (Vec, bool) { - let mut tokens = Vec::new(); - let mut ack_eliciting = false; - // All useful frames are at least 2 bytes. - while builder.len() + 2 < limit { - let remaining = limit - builder.len(); - // Try to get a frame from frame sources - let mut frame = None; - if tx_mode == TxMode::Normal { - frame = self.acks.get_frame(now, space); - } - if frame.is_none() && space == PNSpace::ApplicationData && self.role == Role::Server { - frame = self.state_signaling.send_done(); - } - if frame.is_none() { - frame = self.crypto.streams.get_frame(space, tx_mode, remaining) - } - if frame.is_none() && tx_mode == TxMode::Normal { - frame = self.flow_mgr.borrow_mut().get_frame(space, remaining); - } - if frame.is_none() { - frame = self.send_streams.get_frame(space, tx_mode, remaining); - } - - if let Some((frame, token)) = frame { - ack_eliciting |= frame.ack_eliciting(); - debug_assert_ne!(frame, Frame::Padding); - frame.marshal(builder); - if let Some(t) = token { - tokens.push(t); - } - } else { - if tx_mode == TxMode::Pto { - // Add a PING. - builder.encode_varint(Frame::Ping.get_type()); - ack_eliciting = true; - } - return (tokens, ack_eliciting); - } - - // PTO only ever sends one frame and they always elicit ACKs. - if tx_mode == TxMode::Pto { - debug_assert!(ack_eliciting); - return (tokens, true); - } - } - (tokens, ack_eliciting) - } - - /// Build a datagram, possibly from multiple packets (for different PN - /// spaces) and each containing 1+ frames. - fn output_path(&mut self, path: &mut Path, now: Instant) -> Res> { - let mut needs_padding = false; - - // Check whether we are sending packets in PTO mode. - let (tx_mode, cong_avail, min_pn_space) = - if let Some((min_pto_pn_space, can_send)) = self.loss_recovery.get_pto_state() { - if !can_send { - return Ok(None); - } - (TxMode::Pto, path.mtu(), min_pto_pn_space) - } else { - ( - TxMode::Normal, - usize::try_from(self.loss_recovery.cwnd_avail()).unwrap(), - PNSpace::Initial, - ) - }; - - // Frames for different epochs must go in different packets, but then these - // packets can go in a single datagram - let mut encoder = Encoder::with_capacity(path.mtu()); - for space in PNSpace::iter() { - if *space < min_pn_space { - continue; - } - - // Ensure we have tx crypto state for this epoch, or skip it. - let tx = if let Some(tx_state) = self.crypto.states.tx(*space) { - tx_state - } else { - continue; - }; - - let header_start = encoder.len(); - let (pt, pn, mut builder) = - Self::build_packet_header(path, *space, encoder, tx, &self.retry_info); - let payload_start = builder.len(); - - // Work out how much space we have in the congestion window. - let limit = min(path.mtu(), cong_avail); - if builder.len() + tx.expansion() > limit { - // No space for a packet of this type in the congestion window. - encoder = builder.abort(); - continue; - } - let limit = limit - tx.expansion(); - - let (tokens, ack_eliciting) = - self.add_frames(&mut builder, *space, tx_mode, limit, now); - if builder.is_empty() { - // Nothing to include in this packet. - encoder = builder.abort(); - continue; - } - - dump_packet(self, "TX ->", pt, pn, &builder[payload_start..]); - - qdebug!("Need to send a packet: {:?}", pt); - match pt { + qdebug!("Need to send a packet"); + match epoch { // Packets containing Initial packets need padding. - PacketType::Initial => needs_padding = true, - PacketType::ZeroRtt => (), + 0 => needs_padding = true, + 1 => (), // ...unless they include higher epochs. _ => needs_padding = false, } self.stats.packets_tx += 1; - encoder = builder.build(self.crypto.states.tx(*space).unwrap())?; - assert!(encoder.len() <= path.mtu()); + self.loss_recovery.inc_pn(space); - if tx_mode != TxMode::Pto && ack_eliciting { + let mut packet = encode_packet(tx, &hdr, &encoder); + + if self.tx_mode != TxMode::Pto && ack_eliciting { self.idle_timeout.on_packet_sent(now); } - // Normal packets are in flight if they include PADDING frames, - // but we don't send those. - let in_flight = match tx_mode { + let in_flight = match self.tx_mode { TxMode::Pto => false, - TxMode::Normal => ack_eliciting, + TxMode::Normal => ack_eliciting || has_padding, }; - let sent = SentPacket::new( - now, - ack_eliciting, - tokens, - encoder.len() - header_start, - in_flight, + self.loss_recovery.on_packet_sent( + space, + hdr.pn, + SentPacket::new(now, ack_eliciting, tokens, packet.len(), in_flight), ); - self.loss_recovery.on_packet_sent(*space, pn, sent); - if *space == PNSpace::Handshake && self.role == Role::Client { - // Client can send Handshake packets -> discard Initial keys and states - self.discard_keys(PNSpace::Initial); - } + dump_packet(self, "TX ->", &hdr, &encoder); - if *space == PNSpace::Handshake - && self.role == Role::Server - && self.state == State::Confirmed - { - // We could discard handshake keys in set_state, but we are waiting to send an ack. - self.discard_keys(PNSpace::Handshake); - } + out_bytes.append(&mut packet); } - if encoder.len() == 0 { - assert!(tx_mode != TxMode::Pto); + if close_sent { + self.flow_mgr.borrow_mut().set_need_close_frame(false); + } + + // Sent a probe pkt. Another timeout will re-engage ProbeTimeout mode, + // but otherwise return to honoring CC. + if self.tx_mode == TxMode::Pto { + self.tx_mode = TxMode::Normal; + } + + if out_bytes.is_empty() { + assert!(self.tx_mode != TxMode::Pto); + self.path = Some(path); Ok(None) } else { - debug_assert!(encoder.len() <= path.mtu()); // Pad Initial packets sent by the client to mtu bytes. - let mut packets: Vec = encoder.into(); if self.role == Role::Client && needs_padding { qdebug!([self], "pad Initial to max_datagram_size"); - packets.resize(path.mtu(), 0); + out_bytes.resize(path.mtu(), 0); } - Ok(Some(Datagram::new(path.local, path.remote, packets))) + let ret = Ok(Some(Datagram::new(path.local, path.remote, out_bytes))); + self.path = Some(path); + ret } } - pub fn initiate_key_update(&mut self) -> Res<()> { - if self.state == State::Confirmed { - let la = self - .loss_recovery - .largest_acknowledged_pn(PNSpace::ApplicationData); - qinfo!([self], "Initiating key update"); - self.crypto.states.initiate_key_update(la) - } else { - Err(Error::NotConnected) - } - } - - #[cfg(test)] - pub fn get_epochs(&self) -> (Option, Option) { - self.crypto.states.get_epochs() - } - fn client_start(&mut self, now: Instant) -> Res<()> { qinfo!([self], "client_start"); - self.handshake(now, PNSpace::Initial, None)?; + self.handshake(now, 0, None)?; self.set_state(State::WaitInitial); - self.zero_rtt_state = if self.crypto.enable_0rtt(self.role)? { - qdebug!([self], "Enabled 0-RTT"); - ZeroRttState::Sending - } else { - ZeroRttState::Init - }; + if self.crypto.tls.preinfo()?.early_data() { + qdebug!([self], "Enabling 0-RTT"); + self.zero_rtt_state = match self.get_zero_rtt_crypto() { + Some(cs) => ZeroRttState::Sending(cs), + None => { + debug_assert!(false, "We must have zero-rtt keys."); + ZeroRttState::Rejected + } + }; + } Ok(()) } @@ -1414,42 +1334,37 @@ impl Connection { .conn_increase_max_credit(remote.get_integer(tp_constants::INITIAL_MAX_DATA)); } - fn validate_odcid(&mut self) -> Res<()> { - // Here we drop our Retry state then validate it. - if let Some(info) = self.retry_info.take() { - if info.token.is_empty() { - Ok(()) - } else { - let tph = self.tps.borrow(); - let tp = tph.remote().get_bytes(tp_constants::ORIGINAL_CONNECTION_ID); - if let Some(odcid_tp) = tp { - if odcid_tp[..] == info.odcid[..] { - Ok(()) - } else { - Err(Error::InvalidRetry) - } + fn validate_odcid(&self) -> Res<()> { + if let Some(info) = &self.retry_info { + let tph = self.tps.borrow(); + let tp = tph.remote().get_bytes(tp_constants::ORIGINAL_CONNECTION_ID); + if let Some(odcid_tp) = tp { + if odcid_tp[..] == info.odcid[..] { + Ok(()) } else { Err(Error::InvalidRetry) } + } else { + Err(Error::InvalidRetry) } } else { - debug_assert_eq!(self.role, Role::Server); Ok(()) } } - fn handshake(&mut self, now: Instant, space: PNSpace, data: Option<&[u8]>) -> Res<()> { - qtrace!("Handshake space={} data={:0x?}", space, data); + fn handshake(&mut self, now: Instant, epoch: u16, data: Option<&[u8]>) -> Res<()> { + qdebug!("Handshake epoch={} data={:0x?}", epoch, data); - let rec = data.map(|d| { - qtrace!([self], "Handshake received {:0x?} ", d); - Record { - ct: 22, // TODO(ekr@rtfm.com): Symbolic constants for CT. This is handshake. - epoch: space.into(), - data: d.to_vec(), - } - }); - let try_update = rec.is_some(); + let rec = data + .map(|d| { + qdebug!([self], "Handshake received {:0x?} ", d); + Some(Record { + ct: 22, // TODO(ekr@rtfm.com): Symbolic constants for CT. This is handshake. + epoch, + data: d.to_vec(), + }) + }) + .unwrap_or(None); match self.crypto.tls.handshake_raw(now, rec) { Err(e) => { @@ -1462,22 +1377,20 @@ impl Connection { Ok(msgs) => self.crypto.buffer_records(msgs), } - match self.crypto.tls.state() { - HandshakeState::Authenticated(_) | HandshakeState::InProgress => (), - HandshakeState::AuthenticationPending => self.events.authentication_needed(), - HandshakeState::Complete(_) => { - if !self.state.connected() { - self.set_connected(now)?; - } + if *self.crypto.tls.state() == HandshakeState::AuthenticationPending { + self.events.authentication_needed(); + } else if matches!(self.crypto.tls.state(), HandshakeState::Complete(_)) { + qinfo!([self], "TLS handshake completed"); + + if self.crypto.tls.info().map(SecretAgentInfo::alpn).is_none() { + qwarn!([self], "No ALPN. Closing connection."); + // 120 = no_application_protocol + return Err(Error::CryptoAlert(120)); } - _ => { - unreachable!("Crypto state should not be new or failed after successful handshake") - } - } - // There is a chance that this could be called less often, but getting the - // conditions right is a little tricky, so call it on every CRYPTO frame. - if try_update { - self.crypto.install_keys(self.role); + + self.validate_odcid()?; + self.set_state(State::Connected); + self.set_initial_limits(); } Ok(()) } @@ -1500,9 +1413,8 @@ impl Connection { } } - fn input_frame(&mut self, ptype: PacketType, frame: Frame, now: Instant) -> Res<()> { - if !frame.is_allowed(ptype) { - qerror!("frame not allowed: {:?} {:?}", frame, ptype); + fn input_frame(&mut self, epoch: Epoch, frame: Frame, now: Instant) -> Res<()> { + if !frame.is_allowed(epoch) { return Err(Error::ProtocolViolation); } match frame { @@ -1519,7 +1431,7 @@ impl Connection { ack_ranges, } => { self.handle_ack( - PNSpace::from(ptype), + epoch, largest_acknowledged, ack_delay, first_ack_range, @@ -1548,20 +1460,19 @@ impl Connection { } } Frame::Crypto { offset, data } => { - let space = PNSpace::from(ptype); - qtrace!( + qdebug!( [self], - "Crypto frame on space={} offset={}, data={:0x?}", - space, + "Crypto frame on epoch={} offset={}, data={:0x?}", + epoch, offset, &data ); - self.crypto.streams.inbound_frame(space, offset, data)?; - if self.crypto.streams.data_ready(space) { + self.crypto.streams.inbound_frame(epoch, offset, data)?; + if self.crypto.streams.data_ready(epoch) { let mut buf = Vec::new(); - let read = self.crypto.streams.read_to_end(space, &mut buf)?; + let read = self.crypto.streams.read_to_end(epoch, &mut buf)?; qdebug!("Read {} bytes", read); - self.handshake(now, space, Some(&buf))?; + self.handshake(now, epoch, Some(&buf))?; } } Frame::NewToken { token } => self.token = Some(token), @@ -1657,7 +1568,7 @@ impl Connection { reason_phrase, } => { let reason_phrase = String::from_utf8_lossy(&reason_phrase); - qerror!( + qinfo!( [self], "ConnectionClose received. Error code: {:?} frame type {:x} reason {}", error_code, @@ -1666,13 +1577,6 @@ impl Connection { ); self.set_state(State::Closed(error_code.into())); } - Frame::HandshakeDone => { - if self.role == Role::Server { - return Err(Error::ProtocolViolation); - } - self.set_state(State::Confirmed); - self.discard_keys(PNSpace::Handshake); - } }; Ok(()) @@ -1692,7 +1596,6 @@ impl Connection { &mut self.recv_streams, &mut self.indexes, ), - RecoveryToken::HandshakeDone => self.state_signaling.handshake_done(), } } } @@ -1700,7 +1603,7 @@ impl Connection { fn handle_ack( &mut self, - space: PNSpace, + epoch: Epoch, largest_acknowledged: u64, ack_delay: u64, first_ack_range: u64, @@ -1709,8 +1612,8 @@ impl Connection { ) -> Res<()> { qinfo!( [self], - "Rx ACK space={}, largest_acked={}, first_ack_range={}, ranges={:?}", - space, + "Rx ACK epoch={}, largest_acked={}, first_ack_range={}, ranges={:?}", + epoch, largest_acknowledged, first_ack_range, ack_ranges @@ -1719,7 +1622,7 @@ impl Connection { let acked_ranges = Frame::decode_ack_frame(largest_acknowledged, first_ack_range, ack_ranges)?; let (acked_packets, lost_packets) = self.loss_recovery.on_ack_received( - space, + PNSpace::from(epoch), largest_acknowledged, acked_ranges, Duration::from_millis(ack_delay), @@ -1734,7 +1637,6 @@ impl Connection { RecoveryToken::Flow(ft) => { self.flow_mgr.borrow_mut().acked(ft, &mut self.send_streams) } - RecoveryToken::HandshakeDone => (), } } } @@ -1744,65 +1646,58 @@ impl Connection { /// When the server rejects 0-RTT we need to drop a bunch of stuff. fn client_0rtt_rejected(&mut self) { - if !matches!(self.zero_rtt_state, ZeroRttState::Sending) { + if !matches!(self.zero_rtt_state, ZeroRttState::Sending(..)) { return; } - qdebug!([self], "0-RTT rejected"); // Tell 0-RTT packets that they were "lost". - let dropped = self.loss_recovery.drop_0rtt(); - self.handle_lost_packets(&dropped); - + // TODO(mt) remove these from "bytes in flight" when we + // have a congestion controller. + for dropped in self.loss_recovery.drop_0rtt() { + for token in dropped.tokens { + match token { + RecoveryToken::Ack(_) => {} + RecoveryToken::Stream(st) => self.send_streams.lost(&st), + RecoveryToken::Crypto(ct) => self.crypto.lost(&ct), + RecoveryToken::Flow(ft) => self.flow_mgr.borrow_mut().lost( + &ft, + &mut self.send_streams, + &mut self.recv_streams, + &mut self.indexes, + ), + } + } + } self.send_streams.clear(); self.recv_streams.clear(); self.indexes = StreamIndexes::new(); - self.crypto.states.discard_0rtt_keys(); self.events.client_0rtt_rejected(); } - fn set_connected(&mut self, now: Instant) -> Res<()> { - qinfo!([self], "TLS connection complete"); - if self.crypto.tls.info().map(SecretAgentInfo::alpn).is_none() { - qwarn!([self], "No ALPN. Closing connection."); - // 120 = no_application_protocol - return Err(Error::CryptoAlert(120)); - } - if self.role == Role::Server { - // Remove the randomized client CID from the list of acceptable CIDs. - assert_eq!(1, self.valid_cids.len()); - self.valid_cids.clear(); - } else { - self.zero_rtt_state = if self.crypto.tls.info().unwrap().early_data_accepted() { - ZeroRttState::AcceptedClient - } else { - self.client_0rtt_rejected(); - ZeroRttState::Rejected - }; - } - - // Setting application keys has to occur after 0-RTT rejection. - let pto = self.loss_recovery.pto(); - self.crypto.install_application_keys(now + pto)?; - self.validate_odcid()?; - self.set_initial_limits(); - self.set_state(State::Connected); - if self.role == Role::Server { - self.state_signaling.handshake_done(); - self.set_state(State::Confirmed); - } - qinfo!([self], "Connection established"); - Ok(()) - } - fn set_state(&mut self, state: State) { if state > self.state { qinfo!([self], "State change from {:?} -> {:?}", self.state, state); self.state = state.clone(); match &self.state { + State::Connected => { + if self.role == Role::Server { + // Remove the randomized client CID from the list of acceptable CIDs. + assert_eq!(1, self.valid_cids.len()); + self.valid_cids.clear(); + } else { + self.zero_rtt_state = + if self.crypto.tls.info().unwrap().early_data_accepted() { + ZeroRttState::AcceptedClient + } else { + self.client_0rtt_rejected(); + ZeroRttState::Rejected + } + } + } State::Closing { .. } => { self.send_streams.clear(); self.recv_streams.clear(); - self.state_signaling.close(); + self.flow_mgr.borrow_mut().set_need_close_frame(true); } State::Closed(..) => { // Equivalent to spec's "draining" state -- never send anything. @@ -1861,13 +1756,9 @@ impl Connection { &mut self, stream_id: StreamId, ) -> Res<(Option<&mut SendStream>, Option<&mut RecvStream>)> { - if !self.state.connected() - && !matches!( - (&self.state, &self.zero_rtt_state), - (State::Handshaking, ZeroRttState::AcceptedServer) - ) - { - return Err(Error::ConnectionState); + match (&self.state, &self.zero_rtt_state) { + (State::Connected, _) | (State::Handshaking, ZeroRttState::AcceptedServer(..)) => (), + _ => return Err(Error::ConnectionState), } // May require creating new stream(s) @@ -1974,7 +1865,7 @@ impl Connection { match self.state { State::Closing { .. } | State::Closed { .. } => return Err(Error::ConnectionState), State::WaitInitial | State::Handshaking => { - if !matches!(self.zero_rtt_state, ZeroRttState::Sending) { + if !matches!(self.zero_rtt_state, ZeroRttState::Sending(..)) { return Err(Error::ConnectionState); } } @@ -2147,19 +2038,93 @@ impl Connection { pub fn next_event(&mut self) -> Option { self.events.next_event() } + + fn check_loss_detection_timeout(&mut self, now: Instant) { + qdebug!([self], "check_loss_timeouts"); + + if matches!(self.loss_recovery_state.mode(), LossRecoveryMode::None) { + // LR not the active timer + return; + } + + if self.loss_recovery_state.callback_time() > Some(now) { + // LR timer, but hasn't expired. + return; + } + + // Timer expired and LR was active timer. + match &mut self.loss_recovery_state.mode() { + LossRecoveryMode::None => unreachable!(), + LossRecoveryMode::LostPackets => { + // Time threshold loss detection + let (pn_space, _) = self + .loss_recovery + .get_earliest_loss_time() + .expect("must be sent packets if in LostPackets mode"); + let packets = self.loss_recovery.detect_lost_packets(pn_space, now); + + qinfo!("lost packets: {}", packets.len()); + for lost in packets { + for token in lost.tokens { + match token { + RecoveryToken::Ack(_) => {} // Do nothing + RecoveryToken::Stream(st) => self.send_streams.lost(&st), + RecoveryToken::Crypto(ct) => self.crypto.lost(&ct), + RecoveryToken::Flow(ft) => self.flow_mgr.borrow_mut().lost( + &ft, + &mut self.send_streams, + &mut self.recv_streams, + &mut self.indexes, + ), + } + } + } + } + LossRecoveryMode::PTO => { + qinfo!( + [self], + "check_loss_detection_timeout -send_one_or_two_packets" + ); + self.loss_recovery.increment_pto_count(); + // TODO + // if (has unacknowledged crypto data): + // RetransmitUnackedCryptoData() + // else if (endpoint is client without 1-RTT keys): + // // Client sends an anti-deadlock packet: Initial is padded + // // to earn more anti-amplification credit, + // // a Handshake packet proves address ownership. + // if (has Handshake keys): + // SendOneHandshakePacket() + // else: + // SendOnePaddedInitialPacket() + // TODO + // SendOneOrTwoPackets() + // PTO. Send new data if available, else retransmit old data. + // If neither is available, send a single PING frame. + + // TODO(agrover): determine if new data is available and if so + // send 2 packets worth + // TODO(agrover): else determine if old data is available and if + // so send 2 packets worth + // TODO(agrover): else send a single PING frame + + self.tx_mode = TxMode::Pto; + } + } + } } impl ::std::fmt::Display for Connection { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "{:?} {:p}", self.role, self as *const Self) + write!(f, "{:?} {:p}", self.role, self as *const Connection) } } #[cfg(test)] mod tests { use super::*; - use crate::cc::{INITIAL_CWND_PKTS, MAX_DATAGRAM_SIZE, MIN_CONG_WINDOW}; use crate::frame::{CloseError, StreamType}; + use crate::recovery::{INITIAL_CWND_PKTS, MAX_DATAGRAM_SIZE, MIN_CONG_WINDOW}; use neqo_common::matches; use std::mem; use test_fixture::{self, assertions, fixture_init, loopback, now}; @@ -2183,11 +2148,11 @@ mod tests { // limit dcid to a constant value to make testing easier let mut modded_path = c.path.take().unwrap(); - let mut modded_cid = modded_path.remote_cid.to_vec(); - modded_cid.truncate(8); - modded_path.remote_cid = ConnectionId::from(&modded_cid[..]); + modded_path.remote_cid.0.truncate(8); + let modded_dcid = modded_path.remote_cid.0.clone(); + assert_eq!(modded_dcid.len(), 8); c.path = Some(modded_path); - c.crypto.states.init(Role::Client, &modded_cid); + c.crypto.create_initial_state(Role::Client, &modded_dcid); c } pub fn default_server() -> Connection { @@ -2305,19 +2270,19 @@ mod tests { let out = client.process(out.dgram(), now()); assert!(out.as_dgram_ref().is_some()); qdebug!("Output={:0x?}", out.as_dgram_ref()); - assert_eq!(*client.state(), State::Connected); qdebug!("---- server: FIN -> ACKS"); let out = server.process(out.dgram(), now()); assert!(out.as_dgram_ref().is_some()); qdebug!("Output={:0x?}", out.as_dgram_ref()); - assert_eq!(*server.state(), State::Confirmed); qdebug!("---- client: ACKS -> 0"); let out = client.process(out.dgram(), now()); assert!(out.as_dgram_ref().is_none()); qdebug!("Output={:0x?}", out.as_dgram_ref()); - assert_eq!(*client.state(), State::Confirmed); + + assert_eq!(*client.state(), State::Connected); + assert_eq!(*server.state(), State::Connected); } #[test] @@ -2402,9 +2367,9 @@ mod tests { qdebug!("---- server"); let out = server.process(out.dgram(), now()); assert!(out.as_dgram_ref().is_some()); - assert_eq!(*server.state(), State::Confirmed); + assert_eq!(*server.state(), State::Connected); qdebug!("Output={:0x?}", out.as_dgram_ref()); - // ACK and HANDSHAKE_DONE + // ACKs // -->> nothing qdebug!("---- client"); @@ -2427,7 +2392,6 @@ mod tests { out = client.process(None, now()); } assert_eq!(datagrams.len(), 4); - assert_eq!(*client.state(), State::Confirmed); qdebug!("---- server"); let mut expect_ack = false; @@ -2437,7 +2401,7 @@ mod tests { qdebug!("Output={:0x?}", out.as_dgram_ref()); expect_ack = !expect_ack; } - assert_eq!(*server.state(), State::Confirmed); + assert_eq!(*server.state(), State::Connected); let mut buf = vec![0; 4000]; @@ -2467,7 +2431,8 @@ mod tests { let mut b = server; let mut datagram = None; let is_done = |c: &mut Connection| match c.state() { - State::Confirmed | State::Closing { .. } | State::Closed(..) => true, + // TODO(mt): Finish on Closed and not Closing. + State::Connected | State::Closing { .. } | State::Closed(..) => true, _ => false, }; while !is_done(a) { @@ -2481,12 +2446,13 @@ mod tests { fn connect(client: &mut Connection, server: &mut Connection) { handshake(client, server); - assert_eq!(*client.state(), State::Confirmed); - assert_eq!(*server.state(), State::Confirmed); + assert_eq!(*client.state(), State::Connected); + assert_eq!(*server.state(), State::Connected); } fn assert_error(c: &Connection, err: ConnectionError) { match c.state() { + // TODO(mt): Finish on Closed and not Closing. State::Closing { error, .. } | State::Closed(error) => { assert_eq!(*error, err); } @@ -2552,19 +2518,17 @@ mod tests { assert!(out.as_dgram_ref().is_none()); qdebug!("Output={:0x?}", out.as_dgram_ref()); - // Four packets total received, 1 of them is a dup and one has been dropped because Initial keys - // are dropped. + // Four packets total received, two of them are dups assert_eq!(4, client.stats().packets_rx); - assert_eq!(1, client.stats().dups_rx); - assert_eq!(1, client.stats().dropped_rx); + assert_eq!(2, client.stats().dups_rx); } fn exchange_ticket(client: &mut Connection, server: &mut Connection) -> Vec { server.send_ticket(now(), &[]).expect("can send ticket"); - let ticket = server.process_output(now()).dgram(); - assert!(ticket.is_some()); - client.process_input(ticket.unwrap(), now()); - assert_eq!(*client.state(), State::Confirmed); + let out = server.process_output(now()); + assert!(out.as_dgram_ref().is_some()); + client.process_input(out.dgram().unwrap(), now()); + assert_eq!(*client.state(), State::Connected); client.resumption_token().expect("should have token") } @@ -2589,7 +2553,7 @@ mod tests { error_code: CloseError::Application(42), .. }, - PNSpace::ApplicationData, + 3, ) )); } @@ -2780,7 +2744,7 @@ mod tests { // The server should receive new stream let server_out = server.process(client_after_reject.dgram(), now()); - assert!(server_out.as_dgram_ref().is_none()); // suppress the ack + assert!(server_out.as_dgram_ref().is_some()); // an ack let recvd_stream_evt = |e| matches!(e, ConnectionEvent::NewStream { .. }); assert!(server.events().any(recvd_stream_evt)); } @@ -2807,36 +2771,11 @@ mod tests { assert!(client.events().any(stream_readable)); } - /// Getting the client and server to reach an idle state is surprisingly hard. - /// The server sends HANDSHAKE_DONE at the end of the handshake, and the client - /// doesn't immediately acknowledge it. - - /// Force the client to send an ACK by having the server send two packets out - /// of order. - fn connect_force_idle(client: &mut Connection, server: &mut Connection) { - connect(client, server); - let p1 = send_something(server, now()); - let p2 = send_something(server, now()); - client.process_input(p2, now()); - // Now the client really wants to send an ACK, but hold it back. - let ack = client.process(Some(p1), now()).dgram(); - assert!(ack.is_some()); - // Now the server has its ACK and both should be idle. - assert_eq!( - server.process(ack, now()), - Output::Callback(LOCAL_IDLE_TIMEOUT) - ); - assert_eq!( - client.process_output(now()), - Output::Callback(LOCAL_IDLE_TIMEOUT) - ); - } - #[test] fn idle_timeout() { let mut client = default_client(); let mut server = default_server(); - connect_force_idle(&mut client, &mut server); + connect(&mut client, &mut server); let now = now(); @@ -2845,7 +2784,7 @@ mod tests { // Still connected after 59 seconds. Idle timer not reset client.process(None, now + Duration::from_secs(59)); - assert!(matches!(client.state(), State::Confirmed)); + assert!(matches!(client.state(), State::Connected)); client.process_timer(now + Duration::from_secs(60)); @@ -2857,7 +2796,7 @@ mod tests { fn idle_send_packet1() { let mut client = default_client(); let mut server = default_server(); - connect_force_idle(&mut client, &mut server); + connect(&mut client, &mut server); let now = now(); @@ -2873,7 +2812,7 @@ mod tests { // Still connected after 69 seconds because idle timer reset by outgoing // packet client.process(out.dgram(), now + Duration::from_secs(69)); - assert!(matches!(client.state(), State::Confirmed)); + assert!(matches!(client.state(), State::Connected)); // Not connected after 70 seconds. client.process_timer(now + Duration::from_secs(70)); @@ -2884,7 +2823,7 @@ mod tests { fn idle_send_packet2() { let mut client = default_client(); let mut server = default_server(); - connect_force_idle(&mut client, &mut server); + connect(&mut client, &mut server); let now = now(); @@ -2901,7 +2840,7 @@ mod tests { // Still connected after 69 seconds. client.process(None, now + Duration::from_secs(69)); - assert!(matches!(client.state(), State::Confirmed)); + assert!(matches!(client.state(), State::Connected)); // Not connected after 70 seconds because timer not reset by second // outgoing packet @@ -2913,7 +2852,7 @@ mod tests { fn idle_recv_packet() { let mut client = default_client(); let mut server = default_server(); - connect_force_idle(&mut client, &mut server); + connect(&mut client, &mut server); let now = now(); @@ -2933,11 +2872,11 @@ mod tests { // Still connected after 79 seconds because idle timer reset by received // packet client.process(out.dgram(), now + Duration::from_secs(20)); - assert!(matches!(client.state(), State::Confirmed)); + assert!(matches!(client.state(), State::Connected)); // Still connected after 79 seconds. client.process_timer(now + Duration::from_secs(79)); - assert!(matches!(client.state(), State::Confirmed)); + assert!(matches!(client.state(), State::Connected)); // Not connected after 80 seconds. client.process_timer(now + Duration::from_secs(80)); @@ -2970,7 +2909,7 @@ mod tests { client .stream_send(stream_id, &[b'a'; RX_STREAM_DATA_WINDOW as usize]) .unwrap(), - usize::try_from(SMALL_MAX_DATA).unwrap() + SMALL_MAX_DATA.try_into().unwrap() ); let evts = client.events().collect::>(); assert_eq!(evts.len(), 2); // SendStreamWritable, StateChange(connected) @@ -3044,7 +2983,7 @@ mod tests { let _ = server.process(client4.dgram(), now()); assert_eq!(*client.state(), State::Connected); - assert_eq!(*server.state(), State::Confirmed); + assert_eq!(*server.state(), State::Connected); } #[test] @@ -3184,7 +3123,7 @@ mod tests { fn pto_works_basic() { let mut client = default_client(); let mut server = default_server(); - connect_force_idle(&mut client, &mut server); + connect(&mut client, &mut server); let now = now(); @@ -3211,10 +3150,9 @@ mod tests { let frames = server.test_process_input(out.dgram().unwrap(), now + Duration::from_secs(11)); - assert!(matches!( - frames[0], - (Frame::Stream { .. }, PNSpace::ApplicationData) - )); + assert_eq!(frames[0], (Frame::Ping, 0)); + assert_eq!(frames[1], (Frame::Ping, 2)); + assert!(matches!(frames[2], (Frame::Stream { .. }, 3))); } #[test] @@ -3222,7 +3160,7 @@ mod tests { fn pto_works_ping() { let mut client = default_client(); let mut server = default_server(); - connect_force_idle(&mut client, &mut server); + connect(&mut client, &mut server); let now = now(); @@ -3305,257 +3243,9 @@ mod tests { now + Duration::from_secs(10) + Duration::from_millis(110), ); - assert_eq!(frames[0], (Frame::Ping, PNSpace::ApplicationData)); - } - - #[test] - fn pto_initial() { - let mut now = now(); - - qdebug!("---- client: generate CH"); - let mut client = default_client(); - let pkt1 = client.process(None, now).dgram(); - assert!(pkt1.is_some()); - assert_eq!(pkt1.clone().unwrap().len(), 1232); - - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_millis(120))); - - // Resend initial after PTO. - now += Duration::from_millis(120); - let pkt2 = client.process(None, now).dgram(); - assert!(pkt2.is_some()); - assert_eq!(pkt2.unwrap().len(), 1232); - - let out = client.process(None, now); - // PTO has doubled. - assert_eq!(out, Output::Callback(Duration::from_millis(240))); - - // Server process the first initial pkt. - let mut server = default_server(); - let out = server.process(pkt1, now).dgram(); - assert!(out.is_some()); - - // Client receives ack for the first initial packet as well a Handshake packet. - // After the handshake packet the initial keys and the crypto stream for the initial - // packet number space will be discarded. - // Here only an ack for the Handshake packet will be sent. - now += Duration::from_millis(10); - let out = client.process(out, now).dgram(); - assert!(out.is_some()); - - // We do not have PTO for the resent initial packet any more, because keys are discarded. - // The timeout will be an idle time out of 60s - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_secs(60))); - } - - #[test] - fn pto_handshake() { - let mut now = now(); - // start handshake - let mut client = default_client(); - let mut server = default_server(); - - let pkt = client.process(None, now).dgram(); - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_millis(120))); - - now += Duration::from_millis(10); - let pkt = server.process(pkt, now).dgram(); - - now += Duration::from_millis(10); - let pkt = client.process(pkt, now).dgram(); - - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_secs(60))); - - now += Duration::from_millis(10); - let pkt = server.process(pkt, now).dgram(); - assert!(pkt.is_none()); - - now += Duration::from_millis(10); - client.authenticated(AuthenticationStatus::Ok, now); - - qdebug!("---- client: SH..FIN -> FIN"); - let pkt1 = client.process(None, now).dgram(); - assert!(pkt1.is_some()); - - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_millis(60))); - - // Wait for PTO o expire and resend a handshake packet - now += Duration::from_millis(60); - let pkt2 = client.process(None, now).dgram(); - assert!(pkt2.is_some()); - - // PTO has been doubled. - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_millis(120))); - - now += Duration::from_millis(10); - // Server receives the first packet. - // The output will be a Handshake packet with an ack and a app pn space packet with - // HANDSHAKE_DONE. - let pkt = server.process(pkt1, now).dgram(); - assert!(pkt.is_some()); - - // Check that the second packet(pkt2) has a Handshake and an app pn space packet. - // The server has discarded the Handshake keys already, therefore the handshake packet - // will be dropped. - let dropped_before = server.stats().dropped_rx; - let frames = server.test_process_input(pkt2.unwrap(), now); - assert_eq!(1, server.stats().dropped_rx - dropped_before); - assert!(matches!(frames[0], (Frame::Ping, PNSpace::ApplicationData))); - - now += Duration::from_millis(10); - // Client receive ack for the first packet - let out = client.process(pkt, now); - // Ack delay timer for the packet carrying HANDSHAKE_DONE. - assert_eq!(out, Output::Callback(Duration::from_millis(20))); - - // Let the ack timer expire. - now += Duration::from_millis(20); - let out = client.process(None, now).dgram(); - assert!(out.is_some()); - let out = client.process(None, now); - // The handshake keys are discarded - // Return PTO timer for an app pn space packet (when the Handshake PTO timer has expired, - // a PING in the app pn space has been send as well). - // pto=142.5ms, the PTO packet was sent 40ms ago. The timer will be 102.5ms. - assert_eq!(out, Output::Callback(Duration::from_micros(102_500))); - - // Let PTO expire. We will send a PING only in the APP pn space, the client has discarded - // Handshshake keys. - now += Duration::from_micros(102_500); - let out = client.process(None, now).dgram(); - assert!(out.is_some()); - - now += Duration::from_millis(10); - let frames = server.test_process_input(out.unwrap(), now); - - assert_eq!(frames[0], (Frame::Ping, PNSpace::ApplicationData)); - } - - #[test] - fn test_pto_handshake_and_app_data() { - let mut now = now(); - qdebug!("---- client: generate CH"); - let mut client = default_client(); - let pkt = client.process(None, now); - - now += Duration::from_millis(10); - qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); - let mut server = default_server(); - let pkt = server.process(pkt.dgram(), now); - - now += Duration::from_millis(10); - qdebug!("---- client: cert verification"); - let pkt = client.process(pkt.dgram(), now); - - now += Duration::from_millis(10); - let _pkt = server.process(pkt.dgram(), now); - - now += Duration::from_millis(10); - client.authenticated(AuthenticationStatus::Ok, now); - - assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); - assert_eq!(client.stream_send(2, b"zero").unwrap(), 4); - qdebug!("---- client: SH..FIN -> FIN and 1RTT packet"); - let pkt1 = client.process(None, now).dgram(); - assert!(pkt1.is_some()); - - // Get PTO timer. - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_millis(60))); - - // Wait for PTO o expire and resend a handshake and 1rtt packet - now += Duration::from_millis(60); - let pkt2 = client.process(None, now).dgram(); - assert!(pkt2.is_some()); - - now += Duration::from_millis(10); - let frames = server.test_process_input(pkt2.unwrap(), now); - - assert!(matches!( - frames[0], - (Frame::Crypto { .. }, PNSpace::Handshake) - )); - assert!(matches!( - frames[1], - (Frame::Stream { .. }, PNSpace::ApplicationData) - )); - } - - #[test] - fn test_pto_count_increase_across_spaces() { - let mut now = now(); - qdebug!("---- client: generate CH"); - let mut client = default_client(); - let pkt = client.process(None, now).dgram(); - - now += Duration::from_millis(10); - qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); - let mut server = default_server(); - let pkt = server.process(pkt, now).dgram(); - - now += Duration::from_millis(10); - qdebug!("---- client: cert verification"); - let pkt = client.process(pkt, now).dgram(); - - now += Duration::from_millis(10); - let _pkt = server.process(pkt, now); - - now += Duration::from_millis(10); - client.authenticated(AuthenticationStatus::Ok, now); - - qdebug!("---- client: SH..FIN -> FIN"); - let pkt1 = client.process(None, now).dgram(); - assert!(pkt1.is_some()); - // Get PTO timer. - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_millis(60))); - - now += Duration::from_millis(10); - assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); - assert_eq!(client.stream_send(2, b"zero").unwrap(), 4); - qdebug!("---- client: 1RTT packet"); - let pkt2 = client.process(None, now).dgram(); - assert!(pkt2.is_some()); - - // Get PTO timer. It is the timer for pkt1(handshake pn space). - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_millis(50))); - - // Wait for PTO to expire and resend a handshake and 1rtt packet - now += Duration::from_millis(50); - let pkt3 = client.process(None, now).dgram(); - assert!(pkt3.is_some()); - - // Get PTO timer. It is the timer for pkt2(app pn space). PTO has been doubled. - // pkt2 has been sent 50ms ago (50 + 120 = 170 == 2*85) - let out = client.process(None, now); - assert_eq!(out, Output::Callback(Duration::from_millis(120))); - - // Wait for PTO to expire and resend a handshake and 1rtt packet - now += Duration::from_millis(120); - let pkt4 = client.process(None, now).dgram(); - assert!(pkt4.is_some()); - - now += Duration::from_millis(10); - let frames = server.test_process_input(pkt3.unwrap(), now); - - assert!(matches!( - frames[0], - (Frame::Crypto { .. }, PNSpace::Handshake) - )); - - now += Duration::from_millis(10); - let frames = server.test_process_input(pkt4.unwrap(), now); - assert!(matches!( - frames[1], - (Frame::Stream { .. }, PNSpace::ApplicationData) - )); + assert_eq!(frames[0], (Frame::Ping, 0)); + assert_eq!(frames[1], (Frame::Ping, 2)); + assert_eq!(frames[2], (Frame::Ping, 3)); } #[test] @@ -3563,7 +3253,7 @@ mod tests { fn verify_pkt_honors_mtu() { let mut client = default_client(); let mut server = default_server(); - connect_force_idle(&mut client, &mut server); + connect(&mut client, &mut server); let now = now(); @@ -3636,15 +3326,6 @@ mod tests { ) } - /// This magic number is the size of the Handshake packets sent - /// by the client and acknowledged by the server. - /// As we change how we build packets, or even as NSS changes, - /// this number might be different. The tests that depend on this - /// value could fail as a result of variations, so it's OK to just - /// change this value, but it is good to first understand where the - /// change came from. - const HANDSHAKE_CWND_INCREASE: usize = 631; - #[test] /// Verify initial CWND is honored. fn cc_slow_start() { @@ -3665,12 +3346,16 @@ mod tests { assert_eq!(client.stream_create(StreamType::UniDi).unwrap(), 2); let c_tx_dgrams = send_bytes(&mut client, 2, now); - // Init/Handshake acks have increased cwnd so we actually can + // Init/Handshake acks have increased cwnd by 630 so we actually can // send 11 with the last being shorter + assert_eq!( + c_tx_dgrams.iter().map(|d| d.len()).sum::(), + (INITIAL_CWND_PKTS * MAX_DATAGRAM_SIZE) + 630 + ); assert_eq!(c_tx_dgrams.len(), INITIAL_CWND_PKTS + 1); let (last, rest) = c_tx_dgrams.split_last().unwrap(); assert!(rest.iter().all(|d| d.len() == MAX_DATAGRAM_SIZE)); - assert_eq!(last.len(), HANDSHAKE_CWND_INCREASE); + assert_eq!(last.len(), 630); assert_eq!(client.loss_recovery.cwnd_avail(), 0); } @@ -3694,7 +3379,7 @@ mod tests { assert_eq!(c_tx_dgrams.len(), INITIAL_CWND_PKTS + 1); assert_eq!( c_tx_dgrams.iter().map(|d| d.len()).sum::(), - (INITIAL_CWND_PKTS * MAX_DATAGRAM_SIZE) + HANDSHAKE_CWND_INCREASE + (INITIAL_CWND_PKTS * MAX_DATAGRAM_SIZE) + 630 ); // Server: Receive and generate ack @@ -3712,7 +3397,7 @@ mod tests { largest_acknowledged: INITIAL_CWND_PKTS_U64, .. }, - PNSpace::ApplicationData, + 3, ) )); @@ -3736,7 +3421,7 @@ mod tests { largest_acknowledged: 31, .. }, - PNSpace::ApplicationData, + 3, ) )); @@ -4009,247 +3694,4 @@ mod tests { let c_tx_dgrams = send_bytes(&mut client, 0, now); assert_eq!(c_tx_dgrams.len(), 4); } - - fn check_discarded(peer: &mut Connection, pkt: Datagram, dropped: usize, dups: usize) { - let dropped_before = peer.stats.dropped_rx; - let dups_before = peer.stats.dups_rx; - let out = peer.process(Some(pkt), now()); - assert!(out.as_dgram_ref().is_none()); - assert_eq!(dropped, peer.stats.dropped_rx - dropped_before); - assert_eq!(dups, peer.stats.dups_rx - dups_before); - } - - #[test] - fn discarded_initial_keys() { - qdebug!("---- client: generate CH"); - let mut client = default_client(); - let init_pkt_c = client.process(None, now()).dgram(); - assert!(init_pkt_c.is_some()); - assert_eq!(init_pkt_c.as_ref().unwrap().len(), 1232); - - qdebug!("---- server: CH -> SH, EE, CERT, CV, FIN"); - let mut server = default_server(); - let init_pkt_s = server.process(init_pkt_c.clone(), now()).dgram(); - assert!(init_pkt_s.is_some()); - - qdebug!("---- client: cert verification"); - let out = client.process(init_pkt_s.clone(), now()).dgram(); - assert!(out.is_some()); - - // The client has received handshake packet. It will remove the Initial keys. - // We will check this by processing init_pkt_s a second time. - // The initial packet should be dropped. The packet contains a Handshake packet as well, which - // will be marked as dup. - check_discarded(&mut client, init_pkt_s.unwrap(), 1, 1); - - assert!(maybe_authenticate(&mut client)); - - // The server has not removed the Initial keys yet, because it has not yet received a Handshake - // packet from the client. - // We will check this by processing init_pkt_c a second time. - // The dropped packet is padding. The Initial packet has been mark dup. - check_discarded(&mut server, init_pkt_c.clone().unwrap(), 1, 1); - - qdebug!("---- client: SH..FIN -> FIN"); - let out = client.process(None, now()).dgram(); - assert!(out.is_some()); - - // The server will process the first Handshake packet. - // After this the Initial keys will be dropped. - let out = server.process(out, now()).dgram(); - assert!(out.is_some()); - - // Check that the Initial keys are dropped at the server - // We will check this by processing init_pkt_c a third time. - // The Initial packet has been dropped and padding that follows it. - // There is no dups, everything has been dropped. - check_discarded(&mut server, init_pkt_c.unwrap(), 1, 0); - } - - /// Send something on a stream from `sender` to `receiver`. - /// Return the resulting datagram. - fn send_something(sender: &mut Connection, now: Instant) -> Datagram { - let stream_id = sender.stream_create(StreamType::UniDi).unwrap(); - assert!(sender.stream_send(stream_id, b"data").is_ok()); - assert!(sender.stream_close_send(stream_id).is_ok()); - let dgram = sender.process(None, now).dgram(); - dgram.expect("should have something to send") - } - - /// Send something on a stream from `sender` to `receiver`. - /// Return any ACK that might result. - fn send_and_receive( - sender: &mut Connection, - receiver: &mut Connection, - now: Instant, - ) -> Option { - let dgram = send_something(sender, now); - receiver.process(Some(dgram), now).dgram() - } - - #[test] - fn key_update_client() { - let mut client = default_client(); - let mut server = default_server(); - connect_force_idle(&mut client, &mut server); - let mut now = now(); - - // Both client and server should be idle now. - assert_eq!( - Output::Callback(LOCAL_IDLE_TIMEOUT), - client.process(None, now) - ); - assert_eq!( - Output::Callback(LOCAL_IDLE_TIMEOUT), - server.process(None, now) - ); - assert_eq!(client.get_epochs(), (Some(3), Some(3))); // (write, read) - assert_eq!(server.get_epochs(), (Some(3), Some(3))); - - // TODO(mt) this needs to wait for handshake confirmation, - // but for now, we can do this immediately. - let res = client.initiate_key_update(); - assert!(res.is_ok()); - let res = client.initiate_key_update(); - assert!(res.is_err()); - - // Initiating an update should only increase the write epoch. - assert_eq!( - Output::Callback(LOCAL_IDLE_TIMEOUT), - client.process(None, now) - ); - assert_eq!(client.get_epochs(), (Some(4), Some(3))); - - // Send something to propagate the update. - assert!(send_and_receive(&mut client, &mut server, now).is_none()); - - // The server should now be waiting to discharge read keys. - assert_eq!(server.get_epochs(), (Some(4), Some(3))); - let res = server.process(None, now); - if let Output::Callback(t) = res { - assert!(t < LOCAL_IDLE_TIMEOUT); - } else { - panic!("server should now be waiting to clear keys"); - } - - // Without having had time to purge old keys, more updates are blocked. - // The spec would permits it at this point, but we are more conservative. - assert!(client.initiate_key_update().is_err()); - // The server can't update until it receives an ACK for a packet. - assert!(server.initiate_key_update().is_err()); - - // Waiting now for at least a PTO should cause the server to drop old keys. - // But at this point the client hasn't received a key update from the server. - // It will be stuck with old keys. - now += Duration::from_secs(1); - client.process_timer(now); - assert_eq!(client.get_epochs(), (Some(4), Some(3))); - server.process_timer(now); - assert_eq!(server.get_epochs(), (Some(4), Some(4))); - - // Even though the server has updated, it hasn't received an ACK yet. - assert!(server.initiate_key_update().is_err()); - - // Now get an ACK from the server. - let dgram = send_and_receive(&mut client, &mut server, now); - assert!(dgram.is_some()); - let res = client.process(dgram, now); - // This is the first packet that the client has received from the server - // with new keys, so its read timer just started. - if let Output::Callback(t) = res { - assert!(t < LOCAL_IDLE_TIMEOUT); - } else { - panic!("client should now be waiting to clear keys"); - } - - assert!(client.initiate_key_update().is_err()); - assert_eq!(client.get_epochs(), (Some(4), Some(3))); - // The server can't update until it gets something from the client. - assert!(server.initiate_key_update().is_err()); - - now += Duration::from_secs(1); - client.process_timer(now); - assert_eq!(client.get_epochs(), (Some(4), Some(4))); - } - - #[test] - fn key_update_consecutive() { - let mut client = default_client(); - let mut server = default_server(); - connect(&mut client, &mut server); - let now = now(); - - assert!(server.initiate_key_update().is_ok()); - assert_eq!(server.get_epochs(), (Some(4), Some(3))); - - // Server sends something. - // Send twice and drop the first to induce an ACK from the client. - let _ = send_something(&mut server, now); // Drop this. - - // Another packet from the server will cause the client to ACK and update keys. - let dgram = send_and_receive(&mut server, &mut client, now); - assert!(dgram.is_some()); - assert_eq!(client.get_epochs(), (Some(4), Some(3))); - - // Have the server process the ACK. - if let Output::Callback(_) = server.process(dgram, now) { - assert_eq!(server.get_epochs(), (Some(4), Some(3))); - // Now move the server temporarily into the future so that it - // rotates the keys. Don't do this at home folks. - server.process_timer(now + Duration::from_secs(1)); - assert_eq!(server.get_epochs(), (Some(4), Some(4))); - } else { - panic!("server should have a timer set"); - } - - // Now update keys on the server again. - assert!(server.initiate_key_update().is_ok()); - assert_eq!(server.get_epochs(), (Some(5), Some(4))); - - let dgram = send_something(&mut server, now); - - // However, as the server didn't wait long enough to update again, the - // client hasn't rotated its keys, so the packet gets dropped. - check_discarded(&mut client, dgram, 1, 0); - } - - // Key updates can't be initiated too early. - #[test] - fn key_update_before_confirmed() { - let mut client = default_client(); - assert!(client.initiate_key_update().is_err()); - let mut server = default_server(); - assert!(server.initiate_key_update().is_err()); - - // Client Initial - let dgram = client.process(None, now()).dgram(); - assert!(dgram.is_some()); - assert!(client.initiate_key_update().is_err()); - - // Server Initial + Handshake - let dgram = server.process(dgram, now()).dgram(); - assert!(dgram.is_some()); - assert!(server.initiate_key_update().is_err()); - - // Client Handshake - client.process_input(dgram.unwrap(), now()); - assert!(client.initiate_key_update().is_err()); - - assert!(maybe_authenticate(&mut client)); - assert!(client.initiate_key_update().is_err()); - - let dgram = client.process(None, now()).dgram(); - assert!(dgram.is_some()); - assert!(client.initiate_key_update().is_err()); - - // Server HANDSHAKE_DONE - let dgram = server.process(dgram, now()).dgram(); - assert!(dgram.is_some()); - assert!(server.initiate_key_update().is_ok()); - - // Client receives HANDSHAKE_DONE - let dgram = client.process(dgram, now()).dgram(); - assert!(dgram.is_none()); - assert!(client.initiate_key_update().is_ok()); - } } diff --git a/third_party/rust/neqo-transport/src/crypto.rs b/third_party/rust/neqo-transport/src/crypto.rs index 39406265bd76..96b8dfafa514 100644 --- a/third_party/rust/neqo-transport/src/crypto.rs +++ b/third_party/rust/neqo-transport/src/crypto.rs @@ -5,29 +5,23 @@ // except according to those terms. use std::cell::RefCell; -use std::cmp::max; -use std::mem; -use std::ops::{Index, IndexMut, Range}; use std::rc::Rc; -use std::time::Instant; -use neqo_common::{hex, matches, qdebug, qerror, qinfo, qtrace}; +use neqo_common::{hex, qdebug, qinfo, qtrace}; use neqo_crypto::aead::Aead; use neqo_crypto::hp::HpKey; use neqo_crypto::{ hkdf, Agent, AntiReplay, Cipher, Epoch, RecordList, SymKey, TLS_AES_128_GCM_SHA256, - TLS_AES_256_GCM_SHA384, TLS_EPOCH_APPLICATION_DATA, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL, - TLS_EPOCH_ZERO_RTT, TLS_VERSION_1_3, + TLS_AES_256_GCM_SHA384, TLS_VERSION_1_3, }; use crate::connection::Role; use crate::frame::{Frame, TxMode}; -use crate::packet::PacketNumber; +use crate::packet::{CryptoCtx, PacketNumber}; use crate::recovery::RecoveryToken; use crate::recv_stream::RxStreamOrderer; use crate::send_stream::TxBuffer; use crate::tparams::{TpZeroRttChecker, TransportParametersHandler}; -use crate::tracking::PNSpace; use crate::{Error, Res}; const MAX_AUTH_TAG: usize = 32; @@ -45,7 +39,7 @@ impl Crypto { protocols: &[impl AsRef], tphandler: Rc>, anti_replay: Option<&AntiReplay>, - ) -> Res { + ) -> Res { agent.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?; agent.enable_ciphers(&[TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384])?; agent.set_alpn(protocols)?; @@ -59,114 +53,49 @@ impl Crypto { )?, } agent.extension_handler(0xffa5, tphandler)?; - Ok(Self { + Ok(Crypto { tls: agent, streams: Default::default(), states: Default::default(), }) } - /// Enable 0-RTT and return `true` if it is enabled successfully. - pub fn enable_0rtt(&mut self, role: Role) -> Res { - let info = self.tls.preinfo()?; - // `info.early_data()` returns false for a server, - // so use `early_data_cipher()` to tell if 0-RTT is enabled. - let cipher = info.early_data_cipher(); - if cipher.is_none() { - return Ok(false); - } - let (dir, secret) = match role { - Role::Client => ( - CryptoDxDirection::Write, - self.tls.write_secret(TLS_EPOCH_ZERO_RTT), - ), - Role::Server => ( - CryptoDxDirection::Read, - self.tls.read_secret(TLS_EPOCH_ZERO_RTT), - ), + // Create the initial crypto state. + pub fn create_initial_state(&mut self, role: Role, dcid: &[u8]) { + const CLIENT_INITIAL_LABEL: &str = "client in"; + const SERVER_INITIAL_LABEL: &str = "server in"; + + qinfo!( + [self], + "Creating initial cipher state role={:?} dcid={}", + role, + hex(dcid) + ); + + let (write_label, read_label) = match role { + Role::Client => (CLIENT_INITIAL_LABEL, SERVER_INITIAL_LABEL), + Role::Server => (SERVER_INITIAL_LABEL, CLIENT_INITIAL_LABEL), }; - let secret = secret.ok_or(Error::KeysNotFound)?; - self.states.set_0rtt_keys(dir, &secret, cipher.unwrap()); - Ok(true) - } - pub fn install_keys(&mut self, role: Role) { - if self.tls.state().is_final() { - return; - } - // These functions only work once, but will usually return KeysNotFound. - // Anything else is unusual and worth logging. - if let Err(e) = self.install_handshake_keys() { - qerror!([self], "Error installing handshake keys: {:?}", e); - } - if role == Role::Server { - if let Err(e) = self.install_application_write_key() { - qerror!([self], "Error installing application write key: {:?}", e); - } - } - } - - fn install_handshake_keys(&mut self) -> Res<()> { - qtrace!([self], "Attempt to install handshake keys"); - let write_secret = if let Some(secret) = self.tls.write_secret(TLS_EPOCH_HANDSHAKE) { - secret - } else { - // No keys is fine. - return Ok(()); - }; - let read_secret = self - .tls - .read_secret(TLS_EPOCH_HANDSHAKE) - .ok_or(Error::KeysNotFound)?; - let cipher = match self.tls.info() { - None => self.tls.preinfo()?.cipher_suite(), - Some(info) => Some(info.cipher_suite()), - } - .ok_or(Error::KeysNotFound)?; - self.states - .set_handshake_keys(&write_secret, &read_secret, cipher); - qdebug!([self], "Handshake keys installed"); - Ok(()) - } - - fn install_application_write_key(&mut self) -> Res<()> { - qtrace!([self], "Attempt to install application write key"); - if let Some(secret) = self.tls.write_secret(TLS_EPOCH_APPLICATION_DATA) { - self.states.set_application_write_key(secret)?; - qdebug!([self], "Application write key installed"); - } - Ok(()) - } - - pub fn install_application_keys(&mut self, expire_0rtt: Instant) -> Res<()> { - if let Err(e) = self.install_application_write_key() { - if e != Error::KeysNotFound { - return Err(e); - } - } - let read_secret = self - .tls - .read_secret(TLS_EPOCH_APPLICATION_DATA) - .ok_or(Error::KeysNotFound)?; - self.states - .set_application_read_key(read_secret, expire_0rtt)?; - qdebug!([self], "application read keys installed"); - Ok(()) + self.states.states[0] = Some(CryptoState { + tx: CryptoDxState::new_initial(CryptoDxDirection::Write, write_label, dcid), + rx: CryptoDxState::new_initial(CryptoDxDirection::Read, read_label, dcid), + }); } /// Buffer crypto records for sending. pub fn buffer_records(&mut self, records: RecordList) { for r in records { assert_eq!(r.ct, 22); - qtrace!([self], "Adding CRYPTO data {:?}", r); - self.streams.send(PNSpace::from(r.epoch), &r.data); + qdebug!([self], "Adding CRYPTO data {:?}", r); + self.streams.send(r.epoch, &r.data); } } pub fn acked(&mut self, token: CryptoRecoveryToken) { qinfo!( - "Acked crypto frame space={} offset={} length={}", - token.space, + "Acked crypto frame epoch={} offset={} length={}", + token.epoch, token.offset, token.length ); @@ -175,18 +104,13 @@ impl Crypto { pub fn lost(&mut self, token: &CryptoRecoveryToken) { qinfo!( - "Lost crypto frame space={} offset={} length={}", - token.space, + "Lost crypto frame epoch={} offset={} length={}", + token.epoch, token.offset, token.length ); self.streams.lost(token); } - - pub fn discard(&mut self, space: PNSpace) { - self.states.discard(space); - self.streams.discard(space); - } } impl ::std::fmt::Display for Crypto { @@ -195,7 +119,7 @@ impl ::std::fmt::Display for Crypto { } } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug)] pub enum CryptoDxDirection { Read, Write, @@ -203,20 +127,10 @@ pub enum CryptoDxDirection { #[derive(Debug)] pub struct CryptoDxState { - direction: CryptoDxDirection, - /// The epoch of this crypto state. This initially tracks TLS epochs - /// via DTLS: 0 = initial, 1 = 0-RTT, 2 = handshake, 3 = application. - /// But we don't need to keep that, and QUIC isn't limited in how - /// many times keys can be updated, so we don't use `u16` for this. - epoch: usize, - aead: Aead, - hpkey: HpKey, - /// This tracks the range of packet numbers that have been seen. This allows - /// for verifying that packet numbers before a key update are strictly lower - /// than packet numbers after a key update. - used_pn: Range, - /// This is the minimum allowed. - min_pn: PacketNumber, + pub(crate) direction: CryptoDxDirection, + pub(crate) epoch: Epoch, + pub(crate) aead: Aead, + pub(crate) hpkey: HpKey, } impl CryptoDxState { @@ -225,24 +139,26 @@ impl CryptoDxState { epoch: Epoch, secret: &SymKey, cipher: Cipher, - ) -> Self { + ) -> CryptoDxState { qinfo!( "Making {:?} {} CryptoDxState, cipher={}", direction, epoch, cipher ); - Self { + CryptoDxState { direction, - epoch: usize::from(epoch), + epoch, aead: Aead::new(TLS_VERSION_1_3, cipher, secret, "quic ").unwrap(), hpkey: HpKey::extract(TLS_VERSION_1_3, cipher, secret, "quic hp").unwrap(), - used_pn: 0..0, - min_pn: 0, } } - pub fn new_initial(direction: CryptoDxDirection, label: &str, dcid: &[u8]) -> Self { + pub fn new_initial( + direction: CryptoDxDirection, + label: &str, + dcid: &[u8], + ) -> Option { const INITIAL_SALT: &[u8] = &[ 0xc3, 0xee, 0xf7, 0x12, 0xc7, 0x2e, 0xbb, 0x5a, 0x11, 0xa7, 0xd2, 0x43, 0x2b, 0xb4, 0x63, 0x65, 0xbe, 0xf9, 0xf5, 0x02, @@ -265,112 +181,34 @@ impl CryptoDxState { let secret = hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label).unwrap(); - Self::new(direction, TLS_EPOCH_INITIAL, &secret, cipher) + Some(CryptoDxState::new(direction, 0, &secret, cipher)) } +} - pub fn next(&self, next_secret: &SymKey, cipher: Cipher) -> Self { - let pn = self.next_pn(); - Self { - direction: self.direction, - epoch: self.epoch + 1, - aead: Aead::new(TLS_VERSION_1_3, cipher, next_secret, "quic ").unwrap(), - hpkey: self.hpkey.clone(), - used_pn: pn..pn, - min_pn: pn, - } - } - - #[must_use] - pub fn is_0rtt(&self) -> bool { - self.epoch == usize::from(TLS_EPOCH_ZERO_RTT) - } - - #[must_use] - pub fn key_phase(&self) -> bool { - // Epoch 3 => 0, 4 => 1, 5 => 0, 6 => 1, ... - self.epoch & 1 != 1 - } - - /// This is a continuation of a previous, so adjust the range accordingly. - /// Fail if the two ranges overlap. Do nothing if the directions don't match. - pub fn continuation(&mut self, prev: &Self) -> Res<()> { - debug_assert_eq!(self.direction, prev.direction); - let next = prev.next_pn(); - self.min_pn = next; - // TODO(mt) use Range::is_empty() when available - if self.used_pn.start == self.used_pn.end { - self.used_pn = next..next; - Ok(()) - } else if prev.used_pn.end > self.used_pn.start { - qdebug!( - [self], - "Found packet with too new packet number {} > {}, compared to {}", - self.used_pn.start, - prev.used_pn.end, - prev, - ); - Err(Error::PacketNumberOverlap) - } else { - self.used_pn.start = next; - Ok(()) - } - } - - /// Mark a packet number as used. If this is too low, reject it. - /// Note that this won't catch a value that is too high if packets protected with - /// old keys are received after a key update. That needs to be caught elsewhere. - pub fn used(&mut self, pn: PacketNumber) -> Res<()> { - if pn < self.min_pn { - qdebug!( - [self], - "Found packet with too old packet number: {} < {}", - pn, - self.min_pn - ); - return Err(Error::PacketNumberOverlap); - } - if self.used_pn.start == self.used_pn.end { - self.used_pn.start = pn; - } - self.used_pn.end = max(pn + 1, self.used_pn.end); - Ok(()) - } - - #[must_use] - pub fn needs_update(&self) -> bool { - // Only initiate a key update if we have processed exactly one packet - // and we are in an epoch greater than 3. - self.used_pn.start + 1 == self.used_pn.end - && self.epoch > usize::from(TLS_EPOCH_APPLICATION_DATA) - } - - #[must_use] - pub fn can_update(&self, largest_acknowledged: Option) -> bool { - if let Some(la) = largest_acknowledged { - self.used_pn.contains(&la) - } else { - // If we haven't received any acknowledgments, it's OK to update - // the first application data epoch. - self.epoch == usize::from(TLS_EPOCH_APPLICATION_DATA) - } - } - - pub fn compute_mask(&self, sample: &[u8]) -> Res> { +impl CryptoCtx for CryptoDxState { + fn compute_mask(&self, sample: &[u8]) -> Res> { let mask = self.hpkey.mask(sample)?; - qtrace!([self], "HP sample={} mask={}", hex(sample), hex(&mask)); + qdebug!("HP sample={} mask={}", hex(sample), hex(&mask)); Ok(mask) } - #[must_use] - pub fn next_pn(&self) -> PacketNumber { - self.used_pn.end + fn aead_decrypt(&self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res> { + qinfo!( + [self], + "aead_decrypt pn={} hdr={} body={}", + pn, + hex(hdr), + hex(body) + ); + let mut out = vec![0; body.len()]; + let res = self.aead.decrypt(pn, hdr, body, &mut out)?; + Ok(res.to_vec()) } - pub fn encrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res> { - debug_assert_eq!(self.direction, CryptoDxDirection::Write); - qtrace!( + fn aead_encrypt(&self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res> { + qdebug!( [self], - "encrypt pn={} hdr={} body={}", + "aead_encrypt pn={} hdr={} body={}", pn, hex(hdr), hex(body) @@ -380,38 +218,10 @@ impl CryptoDxState { let mut out = vec![0; size]; let res = self.aead.encrypt(pn, hdr, body, &mut out)?; - qtrace!([self], "encrypt ct={}", hex(res)); - debug_assert_eq!(pn, self.next_pn()); - self.used(pn)?; + qdebug!([self], "aead_encrypt ct={}", hex(res),); + Ok(res.to_vec()) } - - #[must_use] - pub fn expansion(&self) -> usize { - self.aead.expansion() - } - - pub fn decrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res> { - debug_assert_eq!(self.direction, CryptoDxDirection::Read); - qtrace!( - [self], - "decrypt pn={} hdr={} body={}", - pn, - hex(hdr), - hex(body) - ); - let mut out = vec![0; body.len()]; - let res = self.aead.decrypt(pn, hdr, body, &mut out)?; - self.used(pn)?; - Ok(res.to_vec()) - } - - #[cfg(test)] - pub(crate) fn test_default() -> Self { - // This matches the value in packet.rs - const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; - Self::new_initial(CryptoDxDirection::Write, "server in", CLIENT_CID) - } } impl std::fmt::Display for CryptoDxState { @@ -420,383 +230,15 @@ impl std::fmt::Display for CryptoDxState { } } -#[derive(Debug)] +#[derive(Debug, Default)] pub struct CryptoState { - tx: CryptoDxState, - rx: CryptoDxState, -} - -impl Index for CryptoState { - type Output = CryptoDxState; - - fn index(&self, dir: CryptoDxDirection) -> &Self::Output { - match dir { - CryptoDxDirection::Read => &self.rx, - CryptoDxDirection::Write => &self.tx, - } - } -} - -impl IndexMut for CryptoState { - fn index_mut(&mut self, dir: CryptoDxDirection) -> &mut Self::Output { - match dir { - CryptoDxDirection::Read => &mut self.rx, - CryptoDxDirection::Write => &mut self.tx, - } - } -} - -/// `CryptoDxAppData` wraps the state necessary for one direction of application data keys. -/// This includes the secret needed to generate the next set of keys. -#[derive(Debug)] -pub(crate) struct CryptoDxAppData { - dx: CryptoDxState, - cipher: Cipher, - // Not the secret used to create `self.dx`, but the one needed for the next iteration. - next_secret: SymKey, -} - -impl CryptoDxAppData { - pub fn new(dir: CryptoDxDirection, secret: SymKey, cipher: Cipher) -> Res { - Ok(Self { - dx: CryptoDxState::new(dir, TLS_EPOCH_APPLICATION_DATA, &secret, cipher), - cipher, - next_secret: Self::update_secret(cipher, &secret)?, - }) - } - - fn update_secret(cipher: Cipher, secret: &SymKey) -> Res { - let next = hkdf::expand_label(TLS_VERSION_1_3, cipher, secret, &[], "quic ku")?; - Ok(next) - } - - pub fn next(&self) -> Res { - if self.dx.epoch == usize::max_value() { - // Guard against too many key updates. - return Err(Error::KeysNotFound); - } - let next_secret = Self::update_secret(self.cipher, &self.next_secret)?; - Ok(Self { - dx: self.dx.next(&next_secret, self.cipher), - cipher: self.cipher, - next_secret, - }) - } + pub tx: Option, + pub rx: Option, } #[derive(Debug, Default)] pub struct CryptoStates { - initial: Option, - handshake: Option, - zero_rtt: Option, // One direction only! - cipher: Cipher, - app_write: Option, - app_read: Option, - app_read_next: Option, - // If this is set, then we have noticed a genuine update. - // Once this time passes, we should switch in new keys. - read_update_time: Option, -} - -impl CryptoStates { - fn select_or_0rtt<'a>( - app: Option<&'a mut CryptoDxAppData>, - zero_rtt: Option<&'a mut CryptoDxState>, - dir: CryptoDxDirection, - ) -> Option<&'a mut CryptoDxState> { - app.map(|a| &mut a.dx) - .or_else(|| zero_rtt.filter(|z| z.direction == dir)) - } - - pub fn tx<'a>(&'a mut self, space: PNSpace) -> Option<&'a mut CryptoDxState> { - let tx = |x: &'a mut Option| x.as_mut().map(|dx| &mut dx.tx); - match space { - PNSpace::Initial => tx(&mut self.initial), - PNSpace::Handshake => tx(&mut self.handshake), - PNSpace::ApplicationData => Self::select_or_0rtt( - self.app_write.as_mut(), - self.zero_rtt.as_mut(), - CryptoDxDirection::Write, - ), - } - } - - pub fn rx_hp(&mut self, space: PNSpace) -> Option<&mut CryptoDxState> { - match space { - PNSpace::ApplicationData => Self::select_or_0rtt( - self.app_read.as_mut(), - self.zero_rtt.as_mut(), - CryptoDxDirection::Read, - ), - _ => self.rx(space, false), - } - } - - pub fn rx<'a>(&'a mut self, space: PNSpace, key_phase: bool) -> Option<&'a mut CryptoDxState> { - let rx = |x: &'a mut Option| x.as_mut().map(|dx| &mut dx.rx); - match space { - PNSpace::Initial => rx(&mut self.initial), - PNSpace::Handshake => rx(&mut self.handshake), - PNSpace::ApplicationData => { - let app = if let Some(arn) = &self.app_read_next { - if arn.dx.key_phase() == key_phase { - self.app_read_next.as_mut() - } else { - self.app_read.as_mut() - } - } else { - self.app_read.as_mut() - }; - Self::select_or_0rtt(app, self.zero_rtt.as_mut(), CryptoDxDirection::Read) - } - } - } - - /// Create the initial crypto state. - pub fn init(&mut self, role: Role, dcid: &[u8]) { - const CLIENT_INITIAL_LABEL: &str = "client in"; - const SERVER_INITIAL_LABEL: &str = "server in"; - - qinfo!( - [self], - "Creating initial cipher state role={:?} dcid={}", - role, - hex(dcid) - ); - - let (write_label, read_label) = match role { - Role::Client => (CLIENT_INITIAL_LABEL, SERVER_INITIAL_LABEL), - Role::Server => (SERVER_INITIAL_LABEL, CLIENT_INITIAL_LABEL), - }; - - let mut initial = CryptoState { - tx: CryptoDxState::new_initial(CryptoDxDirection::Write, write_label, dcid), - rx: CryptoDxState::new_initial(CryptoDxDirection::Read, read_label, dcid), - }; - if let Some(prev) = &self.initial { - qinfo!( - [self], - "Continue packet numbers for initial after retry (write is {:?})", - prev.rx.used_pn, - ); - initial.tx.continuation(&prev.tx).unwrap(); - } - self.initial = Some(initial); - } - - pub fn set_0rtt_keys(&mut self, dir: CryptoDxDirection, secret: &SymKey, cipher: Cipher) { - self.zero_rtt = Some(CryptoDxState::new(dir, TLS_EPOCH_ZERO_RTT, secret, cipher)); - } - - pub fn discard(&mut self, space: PNSpace) { - match space { - PNSpace::Initial => self.initial = None, - PNSpace::Handshake => self.handshake = None, - PNSpace::ApplicationData => panic!("Can't drop application data keys"), - } - } - - pub fn discard_0rtt_keys(&mut self) { - assert!( - self.app_read.is_none(), - "Can't discard 0-RTT after setting application keys" - ); - self.zero_rtt = None; - } - - pub fn set_handshake_keys( - &mut self, - write_secret: &SymKey, - read_secret: &SymKey, - cipher: Cipher, - ) { - self.cipher = cipher; - self.handshake = Some(CryptoState { - tx: CryptoDxState::new( - CryptoDxDirection::Write, - TLS_EPOCH_HANDSHAKE, - write_secret, - cipher, - ), - rx: CryptoDxState::new( - CryptoDxDirection::Read, - TLS_EPOCH_HANDSHAKE, - read_secret, - cipher, - ), - }); - } - - pub fn set_application_write_key(&mut self, secret: SymKey) -> Res<()> { - debug_assert!(self.app_write.is_none()); - debug_assert_ne!(self.cipher, 0); - let mut app = CryptoDxAppData::new(CryptoDxDirection::Write, secret, self.cipher)?; - if let Some(z) = &self.zero_rtt { - if z.direction == CryptoDxDirection::Write { - app.dx.continuation(z)?; - } - } - self.zero_rtt = None; - self.app_write = Some(app); - Ok(()) - } - - pub fn set_application_read_key(&mut self, secret: SymKey, expire_0rtt: Instant) -> Res<()> { - debug_assert!(self.app_write.is_some(), "should have write keys installed"); - debug_assert!(self.app_read.is_none()); - let mut app = CryptoDxAppData::new(CryptoDxDirection::Read, secret, self.cipher)?; - if let Some(z) = &self.zero_rtt { - if z.direction == CryptoDxDirection::Read { - app.dx.continuation(z)?; - } - self.read_update_time = Some(expire_0rtt); - } - self.app_read_next = Some(app.next()?); - self.app_read = Some(app); - Ok(()) - } - - /// Update the write keys. - pub fn initiate_key_update(&mut self, largest_acknowledged: Option) -> Res<()> { - // Only update if we are able to. We can only do this if we have - // received an acknowledgement for a packet in the current phase. - // Also, skip this if we are waiting for read keys on the existing - // key update to be rolled over. - let write = &self.app_write.as_ref().unwrap().dx; - if write.can_update(largest_acknowledged) && self.read_update_time.is_none() { - // This call additionally checks that we don't advance to the next - // epoch while a key update is in progress. - if self.maybe_update_write()? { - Ok(()) - } else { - qdebug!([self], "Write keys already updated"); - Err(Error::KeyUpdateBlocked) - } - } else { - qdebug!([self], "Waiting for ACK or blocked on read key timer"); - Err(Error::KeyUpdateBlocked) - } - } - - /// Try to update, and return true if it happened. - fn maybe_update_write(&mut self) -> Res { - // Update write keys. But only do so if the write keys are not already - // ahead of the read keys. If we initiated the key update, the write keys - // will already be ahead. - debug_assert!(self.read_update_time.is_none()); - let write = &self.app_write.as_ref().unwrap().dx; - let read = &self.app_read.as_ref().unwrap().dx; - if write.epoch == read.epoch { - qdebug!([self], "Updating write keys to epoch={}", write.epoch + 1); - self.app_write = Some(self.app_write.as_ref().unwrap().next()?); - Ok(true) - } else { - Ok(false) - } - } - - fn has_0rtt_read(&self) -> bool { - self.zero_rtt - .as_ref() - .filter(|z| z.direction == CryptoDxDirection::Read) - .is_some() - } - - /// Prepare to update read keys. This doesn't happen immediately as - /// we want to ensure that we can continue to receive any delayed - /// packets that use the old keys. So we just set a timer. - pub fn key_update_received(&mut self, expiration: Instant) -> Res<()> { - // If we received a key update, then we assume that the peer has - // acknowledged a packet we sent in this epoch. It's OK to do that - // because they aren't allowed to update without first having received - // something from us. If the ACK isn't in the packet that triggered this - // key update, it must be in some other packet they have sent. - let _ = self.maybe_update_write()?; - - // We shouldn't have 0-RTT keys at this point, but if we do, dump them. - debug_assert_eq!(self.read_update_time.is_some(), self.has_0rtt_read()); - if self.has_0rtt_read() { - self.zero_rtt = None; - } - self.read_update_time = Some(expiration); - Ok(()) - } - - #[must_use] - pub fn update_time(&self) -> Option { - self.read_update_time - } - - /// Check if time has passed for updating key update parameters. - /// If it has, then swap keys over and allow more key updates to be initiated. - /// This is also used to discard 0-RTT read keys at the server in the same way. - pub fn check_key_update(&mut self, now: Instant) -> Res<()> { - if let Some(expiry) = self.read_update_time { - // If enough time has passed, then install new keys and clear the timer. - if now >= expiry { - if self.has_0rtt_read() { - qtrace!([self], "Discarding 0-RTT keys"); - self.zero_rtt = None; - } else { - qtrace!([self], "Rotating read keys"); - mem::swap(&mut self.app_read, &mut self.app_read_next); - self.app_read_next = Some(self.app_read.as_ref().unwrap().next()?); - } - self.read_update_time = None; - } - } - Ok(()) - } - - /// Get the current/highest epoch. This returns (write, read) epochs. - #[cfg(test)] - pub fn get_epochs(&self) -> (Option, Option) { - let to_epoch = |app: &Option| app.as_ref().map(|a| a.dx.epoch); - (to_epoch(&self.app_write), to_epoch(&self.app_read)) - } - - /// While we are awaiting the completion of a key update, we might receive - /// valid packets that are protected with old keys. We need to ensure that - /// these don't carry packet numbers higher than those in packets protected - /// with the newer keys. To ensure that, this is called after every decryption. - pub fn check_pn_overlap(&mut self) -> Res<()> { - // We only need to do the check while we are waiting for read keys to be updated. - if self.read_update_time.is_some() { - qtrace!([self], "Checking for PN overlap"); - let next_dx = &mut self.app_read_next.as_mut().unwrap().dx; - next_dx.continuation(&self.app_read.as_ref().unwrap().dx)?; - } - Ok(()) - } - - /// Make some state for removing protection in tests. - #[cfg(test)] - pub(crate) fn test_default() -> Self { - let read = || { - let mut dx = CryptoDxState::test_default(); - dx.direction = CryptoDxDirection::Read; - dx - }; - let app_read = || CryptoDxAppData { - dx: read(), - cipher: TLS_AES_128_GCM_SHA256, - next_secret: hkdf::import_key(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &[0xaa; 32]) - .unwrap(), - }; - Self { - initial: Some(CryptoState { - tx: CryptoDxState::test_default(), - rx: read(), - }), - handshake: None, - zero_rtt: None, - cipher: TLS_AES_128_GCM_SHA256, - app_write: None, - app_read: Some(app_read()), - app_read_next: Some(app_read()), - read_update_time: None, - } - } + pub states: [Option; 4], } impl std::fmt::Display for CryptoStates { @@ -805,114 +247,123 @@ impl std::fmt::Display for CryptoStates { } } -#[derive(Debug, Default)] -pub struct CryptoStream { - tx: TxBuffer, - rx: RxStreamOrderer, +impl CryptoStates { + // Get a crypto state, making it if necessary, otherwise return an error. + pub fn obtain(&mut self, role: Role, epoch: Epoch, tls: &Agent) -> Res<&mut CryptoState> { + #[cfg(debug_assertions)] + let label = format!("{}", self); + #[cfg(not(debug_assertions))] + let label = ""; + + let cs = &mut self.states[epoch as usize]; + if cs.is_none() { + qtrace!([label], "Build crypto state for epoch {}", epoch); + assert!(epoch != 0); // This state is made directly. + + let cipher = match (epoch, tls.info()) { + (1, _) => tls.preinfo()?.early_data_cipher(), + (_, None) => tls.preinfo()?.cipher_suite(), + (_, Some(info)) => Some(info.cipher_suite()), + } + .ok_or_else(|| { + qdebug!([label], "cipher info not available yet"); + Error::KeysNotFound + })?; + + let rx = tls + .read_secret(epoch) + .map(|rs| CryptoDxState::new(CryptoDxDirection::Read, epoch, rs, cipher)); + let tx = tls + .write_secret(epoch) + .map(|ws| CryptoDxState::new(CryptoDxDirection::Write, epoch, ws, cipher)); + + // Validate the key setup. + match (&rx, &tx, role, epoch) { + (None, Some(_), Role::Client, 1) + | (Some(_), None, Role::Server, 1) + | (Some(_), Some(_), _, _) => {} + (None, None, _, _) => { + qdebug!([label], "Keying material not available for epoch {}", epoch); + return Err(Error::KeysNotFound); + } + _ => panic!("bad configuration of keys"), + } + + *cs = Some(CryptoState { rx, tx }); + } + + Ok(cs.as_mut().unwrap()) + } } -#[derive(Debug)] -#[allow(dead_code)] // Suppress false positive: https://github.com/rust-lang/rust/issues/68408 -pub enum CryptoStreams { - Initial { - initial: CryptoStream, - handshake: CryptoStream, - application: CryptoStream, - }, - Handshake { - handshake: CryptoStream, - application: CryptoStream, - }, - ApplicationData { - application: CryptoStream, - }, +#[derive(Debug, Default)] +pub struct CryptoStream { + pub tx: TxBuffer, + pub rx: RxStreamOrderer, +} + +#[derive(Debug, Default)] +pub struct CryptoStreams { + streams: [CryptoStream; 4], } impl CryptoStreams { - pub fn discard(&mut self, space: PNSpace) { - match space { - PNSpace::Initial => { - if let Self::Initial { - handshake, - application, - .. - } = self - { - *self = Self::Handshake { - handshake: mem::take(handshake), - application: mem::take(application), - }; - } - } - PNSpace::Handshake => { - if let Self::Handshake { application, .. } = self { - *self = Self::ApplicationData { - application: mem::take(application), - }; - } else if matches!(self, Self::Initial {..}) { - panic!("Discarding handshake before initial discarded"); - } - } - PNSpace::ApplicationData => panic!("Discarding application data crypto streams"), - } + pub fn send(&mut self, epoch: u16, data: &[u8]) { + self.streams[epoch as usize].tx.send(data); } - pub fn send(&mut self, space: PNSpace, data: &[u8]) { - self[space].tx.send(data); + pub fn inbound_frame(&mut self, epoch: u16, offset: u64, data: Vec) -> Res<()> { + self.streams[epoch as usize].rx.inbound_frame(offset, data) } - pub fn inbound_frame(&mut self, space: PNSpace, offset: u64, data: Vec) -> Res<()> { - self[space].rx.inbound_frame(offset, data) + pub fn data_ready(&self, epoch: u16) -> bool { + self.streams[epoch as usize].rx.data_ready() } - pub fn data_ready(&self, space: PNSpace) -> bool { - self[space].rx.data_ready() - } - - pub fn read_to_end(&mut self, space: PNSpace, buf: &mut Vec) -> Res { - self[space].rx.read_to_end(buf) + pub fn read_to_end(&mut self, epoch: u16, buf: &mut Vec) -> Res { + self.streams[epoch as usize].rx.read_to_end(buf) } pub fn acked(&mut self, token: CryptoRecoveryToken) { - self[token.space] + self.streams[token.epoch as usize] .tx .mark_as_acked(token.offset, token.length) } pub fn lost(&mut self, token: &CryptoRecoveryToken) { - self[token.space] + self.streams[token.epoch as usize] .tx .mark_as_lost(token.offset, token.length) } - pub fn sent(&mut self, space: PNSpace, offset: u64, length: usize) { - self[space].tx.mark_as_sent(offset, length) + pub fn sent(&mut self, epoch: u16, offset: u64, length: usize) { + self.streams[epoch as usize].tx.mark_as_sent(offset, length) } - pub fn next_bytes(&self, space: PNSpace, mode: TxMode) -> Option<(u64, &[u8])> { - self[space].tx.next_bytes(mode) + pub fn next_bytes(&self, epoch: u16, mode: TxMode) -> Option<(u64, &[u8])> { + self.streams[epoch as usize].tx.next_bytes(mode) } pub fn get_frame( &mut self, - space: PNSpace, + epoch: u16, mode: TxMode, remaining: usize, ) -> Option<(Frame, Option)> { - if let Some((offset, data)) = self.next_bytes(space, mode) { + if let Some((offset, data)) = self.next_bytes(epoch, mode) { let (frame, length) = Frame::new_crypto(offset, data, remaining); - self.sent(space, offset, length); + self.sent(epoch, offset, length); qdebug!( - "Emitting crypto frame space={}, offset={}, len={}", - space, + "Emitting crypto frame epoch={}, offset={}, len={}", + epoch, offset, length ); Some(( frame, Some(RecoveryToken::Crypto(CryptoRecoveryToken { - space, + epoch, offset, length, })), @@ -923,64 +374,9 @@ impl CryptoStreams { } } -impl Default for CryptoStreams { - fn default() -> Self { - Self::Initial { - initial: CryptoStream::default(), - handshake: CryptoStream::default(), - application: CryptoStream::default(), - } - } -} - -impl Index for CryptoStreams { - type Output = CryptoStream; - fn index(&self, space: PNSpace) -> &Self::Output { - let (initial, hs, app) = match self { - Self::Initial { - initial, - handshake, - application, - } => (Some(initial), Some(handshake), application), - Self::Handshake { - handshake, - application, - } => (None, Some(handshake), application), - Self::ApplicationData { application } => (None, None, application), - }; - match space { - PNSpace::Initial => initial.expect("Initial state dropped!"), - PNSpace::Handshake => hs.expect("Handshake state dropped!"), - PNSpace::ApplicationData => app, - } - } -} - -impl IndexMut for CryptoStreams { - fn index_mut(&mut self, space: PNSpace) -> &mut Self::Output { - let (initial, hs, app) = match self { - Self::Initial { - initial, - handshake, - application, - } => (Some(initial), Some(handshake), application), - Self::Handshake { - handshake, - application, - } => (None, Some(handshake), application), - Self::ApplicationData { application } => (None, None, application), - }; - match space { - PNSpace::Initial => initial.expect("Initial state dropped!"), - PNSpace::Handshake => hs.expect("Handshake state dropped!"), - PNSpace::ApplicationData => app, - } - } -} - #[derive(Debug, Clone)] pub struct CryptoRecoveryToken { - space: PNSpace, + epoch: u16, offset: u64, length: usize, } diff --git a/third_party/rust/neqo-transport/src/dump.rs b/third_party/rust/neqo-transport/src/dump.rs index e8f5b32ae943..5d129ef8ee6a 100644 --- a/third_party/rust/neqo-transport/src/dump.rs +++ b/third_party/rust/neqo-transport/src/dump.rs @@ -8,16 +8,16 @@ // e.g. "RUST_LOG=neqo_transport::dump neqo-client ..." use crate::connection::Connection; -use crate::frame::Frame; -use crate::packet::{PacketNumber, PacketType}; +use crate::frame::decode_frame; +use crate::packet::PacketHdr; use neqo_common::{qdebug, Decoder}; #[allow(clippy::module_name_repetitions)] -pub fn dump_packet(conn: &Connection, dir: &str, pt: PacketType, pn: PacketNumber, payload: &[u8]) { +pub fn dump_packet(conn: &Connection, dir: &str, hdr: &PacketHdr, payload: &[u8]) { let mut s = String::from(""); let mut d = Decoder::from(payload); while d.remaining() > 0 { - let f = match Frame::decode(&mut d) { + let f = match decode_frame(&mut d) { Ok(f) => f, Err(_) => { s.push_str(" [broken]..."); @@ -28,5 +28,5 @@ pub fn dump_packet(conn: &Connection, dir: &str, pt: PacketType, pn: PacketNumbe s.push_str(&format!("\n {} {}", dir, &x)); } } - qdebug!([conn], "pn={} type={:?}{}", pn, pt, s); + qdebug!([conn], "pn={} type={:?}{}", hdr.pn, hdr.tipe, s); } diff --git a/third_party/rust/neqo-transport/src/flow_mgr.rs b/third_party/rust/neqo-transport/src/flow_mgr.rs index aedd1635928f..c3f02eabd5d2 100644 --- a/third_party/rust/neqo-transport/src/flow_mgr.rs +++ b/third_party/rust/neqo-transport/src/flow_mgr.rs @@ -11,13 +11,13 @@ use std::collections::HashMap; use std::mem; use neqo_common::{qinfo, qtrace, qwarn, Encoder}; +use neqo_crypto::Epoch; use crate::frame::{Frame, StreamType}; use crate::recovery::RecoveryToken; use crate::recv_stream::RecvStreams; use crate::send_stream::SendStreams; use crate::stream_id::{StreamId, StreamIndex, StreamIndexes}; -use crate::tracking::PNSpace; use crate::AppError; pub type FlowControlRecoveryToken = Frame; @@ -37,6 +37,8 @@ pub struct FlowMgr { used_data: u64, max_data: u64, + + need_close_frame: bool, } impl FlowMgr { @@ -122,16 +124,6 @@ impl FlowMgr { .insert((stream_id, mem::discriminant(&frame)), frame); } - /// Don't send stream data updates if no more data is coming - pub fn clear_max_stream_data(&mut self, stream_id: StreamId) { - let frame = Frame::MaxStreamData { - stream_id, - maximum_stream_data: 0, - }; - self.from_streams - .remove(&(stream_id, mem::discriminant(&frame))); - } - /// Indicate to receiving remote we need more credits pub fn stream_data_blocked(&mut self, stream_id: StreamId, stream_data_limit: u64) { let frame = Frame::StreamDataBlocked { @@ -174,6 +166,14 @@ impl FlowMgr { } } + pub fn need_close_frame(&self) -> bool { + self.need_close_frame + } + + pub fn set_need_close_frame(&mut self, new: bool) { + self.need_close_frame = new + } + pub(crate) fn acked( &mut self, token: FlowControlRecoveryToken, @@ -291,10 +291,10 @@ impl FlowMgr { pub(crate) fn get_frame( &mut self, - space: PNSpace, + epoch: Epoch, remaining: usize, ) -> Option<(Frame, Option)> { - if space != PNSpace::ApplicationData { + if epoch != 3 { return None; } diff --git a/third_party/rust/neqo-transport/src/frame.rs b/third_party/rust/neqo-transport/src/frame.rs index 08e437278681..d29a6a1157fa 100644 --- a/third_party/rust/neqo-transport/src/frame.rs +++ b/third_party/rust/neqo-transport/src/frame.rs @@ -6,9 +6,9 @@ // Directly relating to QUIC frames. -use neqo_common::{matches, qdebug, qtrace, Decoder, Encoder}; +use neqo_common::{matches, qdebug, Decoder, Encoder}; +use neqo_crypto::Epoch; -use crate::packet::PacketType; use crate::stream_id::{StreamId, StreamIndex}; use crate::{AppError, TransportError}; use crate::{ConnectionError, Error, Res}; @@ -43,7 +43,6 @@ const FRAME_TYPE_PATH_CHALLENGE: FrameType = 0x1a; const FRAME_TYPE_PATH_RESPONSE: FrameType = 0x1b; const FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT: FrameType = 0x1c; const FRAME_TYPE_CONNECTION_CLOSE_APPLICATION: FrameType = 0x1d; -const FRAME_TYPE_HANDSHAKE_DONE: FrameType = 0x1e; const STREAM_FRAME_BIT_FIN: u64 = 0x01; const STREAM_FRAME_BIT_LEN: u64 = 0x02; @@ -59,15 +58,15 @@ pub enum StreamType { impl StreamType { fn frame_type_bit(self) -> u64 { match self { - Self::BiDi => 0, - Self::UniDi => 1, + StreamType::BiDi => 0, + StreamType::UniDi => 1, } } - fn from_type_bit(bit: u64) -> Self { + fn from_type_bit(bit: u64) -> StreamType { if (bit & 0x01) == 0 { - Self::BiDi + StreamType::BiDi } else { - Self::UniDi + StreamType::UniDi } } } @@ -81,22 +80,22 @@ pub enum CloseError { impl CloseError { fn frame_type_bit(self) -> u64 { match self { - Self::Transport(_) => 0, - Self::Application(_) => 1, + CloseError::Transport(_) => 0, + CloseError::Application(_) => 1, } } - fn from_type_bit(bit: u64, code: u64) -> Self { + fn from_type_bit(bit: u64, code: u64) -> CloseError { if (bit & 0x01) == 0 { - Self::Transport(code) + CloseError::Transport(code) } else { - Self::Application(code) + CloseError::Application(code) } } fn code(&self) -> u64 { match self { - Self::Transport(c) | Self::Application(c) => *c, + CloseError::Transport(c) | CloseError::Application(c) => *c, } } } @@ -104,8 +103,8 @@ impl CloseError { impl From for CloseError { fn from(err: ConnectionError) -> Self { match err { - ConnectionError::Transport(c) => Self::Transport(c.code()), - ConnectionError::Application(c) => Self::Application(c), + ConnectionError::Transport(c) => CloseError::Transport(c.code()), + ConnectionError::Application(c) => CloseError::Application(c), } } } @@ -191,20 +190,19 @@ pub enum Frame { frame_type: u64, reason_phrase: Vec, }, - HandshakeDone, } impl Frame { pub fn get_type(&self) -> FrameType { match self { - Self::Padding => FRAME_TYPE_PADDING, - Self::Ping => FRAME_TYPE_PING, - Self::Ack { .. } => FRAME_TYPE_ACK, // We don't do ACK ECN. - Self::ResetStream { .. } => FRAME_TYPE_RST_STREAM, - Self::StopSending { .. } => FRAME_TYPE_STOP_SENDING, - Self::Crypto { .. } => FRAME_TYPE_CRYPTO, - Self::NewToken { .. } => FRAME_TYPE_NEW_TOKEN, - Self::Stream { + Frame::Padding => FRAME_TYPE_PADDING, + Frame::Ping => FRAME_TYPE_PING, + Frame::Ack { .. } => FRAME_TYPE_ACK, // We don't do ACK ECN. + Frame::ResetStream { .. } => FRAME_TYPE_RST_STREAM, + Frame::StopSending { .. } => FRAME_TYPE_STOP_SENDING, + Frame::Crypto { .. } => FRAME_TYPE_CRYPTO, + Frame::NewToken { .. } => FRAME_TYPE_NEW_TOKEN, + Frame::Stream { fin, offset, fill, .. } => { let mut t = FRAME_TYPE_STREAM; @@ -219,29 +217,28 @@ impl Frame { } t } - Self::MaxData { .. } => FRAME_TYPE_MAX_DATA, - Self::MaxStreamData { .. } => FRAME_TYPE_MAX_STREAM_DATA, - Self::MaxStreams { stream_type, .. } => { + Frame::MaxData { .. } => FRAME_TYPE_MAX_DATA, + Frame::MaxStreamData { .. } => FRAME_TYPE_MAX_STREAM_DATA, + Frame::MaxStreams { stream_type, .. } => { FRAME_TYPE_MAX_STREAMS_BIDI + stream_type.frame_type_bit() } - Self::DataBlocked { .. } => FRAME_TYPE_DATA_BLOCKED, - Self::StreamDataBlocked { .. } => FRAME_TYPE_STREAM_DATA_BLOCKED, - Self::StreamsBlocked { stream_type, .. } => { + Frame::DataBlocked { .. } => FRAME_TYPE_DATA_BLOCKED, + Frame::StreamDataBlocked { .. } => FRAME_TYPE_STREAM_DATA_BLOCKED, + Frame::StreamsBlocked { stream_type, .. } => { FRAME_TYPE_STREAMS_BLOCKED_BIDI + stream_type.frame_type_bit() } - Self::NewConnectionId { .. } => FRAME_TYPE_NEW_CONNECTION_ID, - Self::RetireConnectionId { .. } => FRAME_TYPE_RETIRE_CONNECTION_ID, - Self::PathChallenge { .. } => FRAME_TYPE_PATH_CHALLENGE, - Self::PathResponse { .. } => FRAME_TYPE_PATH_RESPONSE, - Self::ConnectionClose { error_code, .. } => { + Frame::NewConnectionId { .. } => FRAME_TYPE_NEW_CONNECTION_ID, + Frame::RetireConnectionId { .. } => FRAME_TYPE_RETIRE_CONNECTION_ID, + Frame::PathChallenge { .. } => FRAME_TYPE_PATH_CHALLENGE, + Frame::PathResponse { .. } => FRAME_TYPE_PATH_RESPONSE, + Frame::ConnectionClose { error_code, .. } => { FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT + error_code.frame_type_bit() } - Self::HandshakeDone => FRAME_TYPE_HANDSHAKE_DONE, } } /// Create a CRYPTO frame that fits the available space and its length. - pub fn new_crypto(offset: u64, data: &[u8], space: usize) -> (Self, usize) { + pub fn new_crypto(offset: u64, data: &[u8], space: usize) -> (Frame, usize) { // Subtract the frame type and offset from available space. let mut remaining = space - 1 - Encoder::varint_len(offset); // Then subtract space for the length field. @@ -249,7 +246,7 @@ impl Frame { remaining -= Encoder::varint_len(u64::try_from(data_len).unwrap()); remaining = min(data.len(), remaining); ( - Self::Crypto { + Frame::Crypto { offset, data: data[..remaining].to_vec(), }, @@ -265,7 +262,7 @@ impl Frame { data: &[u8], fin: bool, space: usize, - ) -> Option<(Self, usize)> { + ) -> Option<(Frame, usize)> { let mut overhead = 1 + Encoder::varint_len(stream_id); if offset > 0 { overhead += Encoder::varint_len(offset); @@ -312,7 +309,7 @@ impl Frame { ); Some(( - Self::Stream { + Frame::Stream { stream_id: stream_id.into(), offset, data: data[..data_len].to_vec(), @@ -327,8 +324,8 @@ impl Frame { enc.encode_varint(self.get_type()); match self { - Self::Padding | Self::Ping => (), - Self::Ack { + Frame::Padding | Frame::Ping => (), + Frame::Ack { largest_acknowledged, ack_delay, first_ack_range, @@ -343,7 +340,7 @@ impl Frame { enc.encode_varint(r.range); } } - Self::ResetStream { + Frame::ResetStream { stream_id, application_error_code, final_size, @@ -352,21 +349,21 @@ impl Frame { enc.encode_varint(*application_error_code); enc.encode_varint(*final_size); } - Self::StopSending { + Frame::StopSending { stream_id, application_error_code, } => { enc.encode_varint(stream_id.as_u64()); enc.encode_varint(*application_error_code); } - Self::Crypto { offset, data } => { + Frame::Crypto { offset, data } => { enc.encode_varint(*offset); enc.encode_vvec(&data); } - Self::NewToken { token } => { + Frame::NewToken { token } => { enc.encode_vvec(token); } - Self::Stream { + Frame::Stream { stream_id, offset, data, @@ -383,35 +380,35 @@ impl Frame { enc.encode_vvec(&data); } } - Self::MaxData { maximum_data } => { + Frame::MaxData { maximum_data } => { enc.encode_varint(*maximum_data); } - Self::MaxStreamData { + Frame::MaxStreamData { stream_id, maximum_stream_data, } => { enc.encode_varint(stream_id.as_u64()); enc.encode_varint(*maximum_stream_data); } - Self::MaxStreams { + Frame::MaxStreams { maximum_streams, .. } => { enc.encode_varint(maximum_streams.as_u64()); } - Self::DataBlocked { data_limit } => { + Frame::DataBlocked { data_limit } => { enc.encode_varint(*data_limit); } - Self::StreamDataBlocked { + Frame::StreamDataBlocked { stream_id, stream_data_limit, } => { enc.encode_varint(stream_id.as_u64()); enc.encode_varint(*stream_data_limit); } - Self::StreamsBlocked { stream_limit, .. } => { + Frame::StreamsBlocked { stream_limit, .. } => { enc.encode_varint(stream_limit.as_u64()); } - Self::NewConnectionId { + Frame::NewConnectionId { sequence_number, retire_prior, connection_id, @@ -423,16 +420,16 @@ impl Frame { enc.encode(connection_id); enc.encode(stateless_reset_token); } - Self::RetireConnectionId { sequence_number } => { + Frame::RetireConnectionId { sequence_number } => { enc.encode_varint(*sequence_number); } - Self::PathChallenge { data } => { + Frame::PathChallenge { data } => { enc.encode(data); } - Self::PathResponse { data } => { + Frame::PathResponse { data } => { enc.encode(data); } - Self::ConnectionClose { + Frame::ConnectionClose { error_code, frame_type, reason_phrase, @@ -441,12 +438,11 @@ impl Frame { enc.encode_varint(*frame_type); enc.encode_vvec(reason_phrase); } - Self::HandshakeDone => (), } } pub fn ack_eliciting(&self) -> bool { - !matches!(self, Self::Ack { .. } | Self::Padding | Self::ConnectionClose { .. }) + !matches!(self, Frame::Ack { .. } | Frame::Padding | Frame::ConnectionClose { .. }) } /// Converts AckRanges as encoded in a ACK frame (see -transport @@ -494,12 +490,12 @@ impl Frame { pub fn dump(&self) -> Option { match self { - Self::Crypto { offset, data } => Some(format!( + Frame::Crypto { offset, data } => Some(format!( "Crypto {{ offset: {}, len: {} }}", offset, data.len() )), - Self::Stream { + Frame::Stream { stream_id, offset, fill, @@ -513,184 +509,187 @@ impl Frame { data.len(), fin, )), - Self::Padding => None, + Frame::Padding => None, _ => Some(format!("{:?}", self)), } } - pub fn is_allowed(&self, pt: PacketType) -> bool { - if matches!(self, Self::Padding | Self::Ping) { + pub fn is_allowed(&self, epoch: Epoch) -> bool { + qdebug!("is_allowed {:?} {}", self, epoch); + if matches!(self, Frame::Padding | Frame::Ping) { true - } else if matches!(self, Self::Crypto {..} | Self::Ack {..} | Self::ConnectionClose { error_code: CloseError::Transport(_), .. }) + } else if matches!(self, Frame::Crypto {..} | Frame::Ack {..} | Frame::ConnectionClose { error_code: CloseError::Transport(_), .. }) { - pt != PacketType::ZeroRtt - } else if matches!(self, Self::NewToken {..} | Self::ConnectionClose {..}) { - pt == PacketType::Short + epoch != 1 + } else if matches!(self, Frame::NewToken {..} | Frame::ConnectionClose {..}) { + epoch >= 3 } else { - pt == PacketType::ZeroRtt || pt == PacketType::Short + epoch == 1 || epoch >= 3 // Application data } } +} - pub fn decode(dec: &mut Decoder) -> Res { - macro_rules! d { - ($d:expr) => { - match $d { - Some(v) => v, - _ => return Err(Error::NoMoreData), - } - }; - } - macro_rules! dv { - ($d:expr) => { - d!($d.decode_varint()) - }; - } - - // TODO(ekr@rtfm.com): check for minimal encoding - let t = d!(dec.decode_varint()); - match t { - FRAME_TYPE_PADDING => Ok(Self::Padding), - FRAME_TYPE_PING => Ok(Self::Ping), - FRAME_TYPE_RST_STREAM => Ok(Self::ResetStream { - stream_id: dv!(dec).into(), - application_error_code: d!(dec.decode_varint()), - final_size: match dec.decode_varint() { - Some(v) => v, - _ => return Err(Error::NoMoreData), - }, - }), - FRAME_TYPE_ACK | FRAME_TYPE_ACK_ECN => { - let la = dv!(dec); - let ad = dv!(dec); - let nr = dv!(dec); - let fa = dv!(dec); - let mut arr: Vec = Vec::with_capacity(nr as usize); - for _ in 0..nr { - let ar = AckRange { - gap: dv!(dec), - range: dv!(dec), - }; - arr.push(ar); - } - - // Now check for the values for ACK_ECN. - if t == FRAME_TYPE_ACK_ECN { - dv!(dec); - dv!(dec); - dv!(dec); - } - - Ok(Self::Ack { - largest_acknowledged: la, - ack_delay: ad, - first_ack_range: fa, - ack_ranges: arr, - }) +#[allow(clippy::module_name_repetitions)] +pub fn decode_frame(dec: &mut Decoder) -> Res { + macro_rules! d { + ($d:expr) => { + match $d { + Some(v) => v, + _ => return Err(Error::NoMoreData), } - FRAME_TYPE_STOP_SENDING => Ok(Self::StopSending { - stream_id: dv!(dec).into(), - application_error_code: d!(dec.decode_varint()), - }), - FRAME_TYPE_CRYPTO => { - let o = dv!(dec); - Ok(Self::Crypto { - offset: o, - data: d!(dec.decode_vvec()).to_vec(), // TODO(mt) unnecessary copy - }) - } - FRAME_TYPE_NEW_TOKEN => { - Ok(Self::NewToken { - token: d!(dec.decode_vvec()).to_vec(), // TODO(mt) unnecessary copy - }) - } - FRAME_TYPE_STREAM..=FRAME_TYPE_STREAM_MAX => { - let s = dv!(dec); - let o = if t & STREAM_FRAME_BIT_OFF == 0 { - 0 - } else { - dv!(dec) + }; + } + macro_rules! dv { + ($d:expr) => { + d!($d.decode_varint()) + }; + } + + // TODO(ekr@rtfm.com): check for minimal encoding + let t = d!(dec.decode_varint()); + qdebug!("Frame type byte={:0x}", t); + match t { + FRAME_TYPE_PADDING => Ok(Frame::Padding), + FRAME_TYPE_PING => Ok(Frame::Ping), + FRAME_TYPE_RST_STREAM => Ok(Frame::ResetStream { + stream_id: dv!(dec).into(), + application_error_code: d!(dec.decode_varint()), + final_size: match dec.decode_varint() { + Some(v) => v, + _ => return Err(Error::NoMoreData), + }, + }), + FRAME_TYPE_ACK | FRAME_TYPE_ACK_ECN => { + let la = dv!(dec); + let ad = dv!(dec); + let nr = dv!(dec); + let fa = dv!(dec); + let mut arr: Vec = Vec::with_capacity(nr as usize); + for _ in 0..nr { + let ar = AckRange { + gap: dv!(dec), + range: dv!(dec), }; - let fill = (t & STREAM_FRAME_BIT_LEN) == 0; - let data = if fill { - qtrace!("STREAM frame, extends to the end of the packet"); - dec.decode_remainder() - } else { - qtrace!("STREAM frame, with length"); - d!(dec.decode_vvec()) - }; - Ok(Self::Stream { - fin: (t & STREAM_FRAME_BIT_FIN) != 0, - stream_id: s.into(), - offset: o, - data: data.to_vec(), // TODO(mt) unnecessary copy. - fill, - }) + arr.push(ar); } - FRAME_TYPE_MAX_DATA => Ok(Self::MaxData { - maximum_data: dv!(dec), - }), - FRAME_TYPE_MAX_STREAM_DATA => Ok(Self::MaxStreamData { - stream_id: dv!(dec).into(), - maximum_stream_data: dv!(dec), - }), - FRAME_TYPE_MAX_STREAMS_BIDI | FRAME_TYPE_MAX_STREAMS_UNIDI => Ok(Self::MaxStreams { + + // Now check for the values for ACK_ECN. + if t == FRAME_TYPE_ACK_ECN { + dv!(dec); + dv!(dec); + dv!(dec); + } + + Ok(Frame::Ack { + largest_acknowledged: la, + ack_delay: ad, + first_ack_range: fa, + ack_ranges: arr, + }) + } + FRAME_TYPE_STOP_SENDING => Ok(Frame::StopSending { + stream_id: dv!(dec).into(), + application_error_code: d!(dec.decode_varint()), + }), + FRAME_TYPE_CRYPTO => { + let o = dv!(dec); + Ok(Frame::Crypto { + offset: o, + data: d!(dec.decode_vvec()).to_vec(), // TODO(mt) unnecessary copy + }) + } + FRAME_TYPE_NEW_TOKEN => { + Ok(Frame::NewToken { + token: d!(dec.decode_vvec()).to_vec(), // TODO(mt) unnecessary copy + }) + } + FRAME_TYPE_STREAM..=FRAME_TYPE_STREAM_MAX => { + let s = dv!(dec); + let o = if t & STREAM_FRAME_BIT_OFF == 0 { + 0 + } else { + dv!(dec) + }; + qdebug!("STREAM {}", t); + let fill = (t & STREAM_FRAME_BIT_LEN) == 0; + let data = if fill { + qdebug!("STREAM frame extends to the end of the packet"); + dec.decode_remainder() + } else { + qdebug!("STREAM frame has a length"); + d!(dec.decode_vvec()) + }; + Ok(Frame::Stream { + fin: (t & STREAM_FRAME_BIT_FIN) != 0, + stream_id: s.into(), + offset: o, + data: data.to_vec(), // TODO(mt) unnecessary copy. + fill, + }) + } + FRAME_TYPE_MAX_DATA => Ok(Frame::MaxData { + maximum_data: dv!(dec), + }), + FRAME_TYPE_MAX_STREAM_DATA => Ok(Frame::MaxStreamData { + stream_id: dv!(dec).into(), + maximum_stream_data: dv!(dec), + }), + FRAME_TYPE_MAX_STREAMS_BIDI | FRAME_TYPE_MAX_STREAMS_UNIDI => Ok(Frame::MaxStreams { + stream_type: StreamType::from_type_bit(t), + maximum_streams: StreamIndex::new(dv!(dec)), + }), + + FRAME_TYPE_DATA_BLOCKED => Ok(Frame::DataBlocked { + data_limit: dv!(dec), + }), + FRAME_TYPE_STREAM_DATA_BLOCKED => Ok(Frame::StreamDataBlocked { + stream_id: dv!(dec).into(), + stream_data_limit: dv!(dec), + }), + FRAME_TYPE_STREAMS_BLOCKED_BIDI | FRAME_TYPE_STREAMS_BLOCKED_UNIDI => { + Ok(Frame::StreamsBlocked { stream_type: StreamType::from_type_bit(t), - maximum_streams: StreamIndex::new(dv!(dec)), - }), - - FRAME_TYPE_DATA_BLOCKED => Ok(Self::DataBlocked { - data_limit: dv!(dec), - }), - FRAME_TYPE_STREAM_DATA_BLOCKED => Ok(Self::StreamDataBlocked { - stream_id: dv!(dec).into(), - stream_data_limit: dv!(dec), - }), - FRAME_TYPE_STREAMS_BLOCKED_BIDI | FRAME_TYPE_STREAMS_BLOCKED_UNIDI => { - Ok(Self::StreamsBlocked { - stream_type: StreamType::from_type_bit(t), - stream_limit: StreamIndex::new(dv!(dec)), - }) - } - FRAME_TYPE_NEW_CONNECTION_ID => { - let s = dv!(dec); - let retire_prior = dv!(dec); - let cid = d!(dec.decode_vec(1)).to_vec(); // TODO(mt) unnecessary copy - let srt = d!(dec.decode(16)); - let mut srtv: [u8; 16] = [0; 16]; - srtv.copy_from_slice(&srt); - - Ok(Self::NewConnectionId { - sequence_number: s, - retire_prior, - connection_id: cid, - stateless_reset_token: srtv, - }) - } - FRAME_TYPE_RETIRE_CONNECTION_ID => Ok(Self::RetireConnectionId { - sequence_number: dv!(dec), - }), - FRAME_TYPE_PATH_CHALLENGE => { - let data = d!(dec.decode(8)); - let mut datav: [u8; 8] = [0; 8]; - datav.copy_from_slice(&data); - Ok(Self::PathChallenge { data: datav }) - } - FRAME_TYPE_PATH_RESPONSE => { - let data = d!(dec.decode(8)); - let mut datav: [u8; 8] = [0; 8]; - datav.copy_from_slice(&data); - Ok(Self::PathResponse { data: datav }) - } - FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT | FRAME_TYPE_CONNECTION_CLOSE_APPLICATION => { - Ok(Self::ConnectionClose { - error_code: CloseError::from_type_bit(t, d!(dec.decode_varint())), - frame_type: dv!(dec), - reason_phrase: d!(dec.decode_vvec()).to_vec(), // TODO(mt) unnecessary copy - }) - } - FRAME_TYPE_HANDSHAKE_DONE => Ok(Self::HandshakeDone), - _ => Err(Error::UnknownFrameType), + stream_limit: StreamIndex::new(dv!(dec)), + }) } + FRAME_TYPE_NEW_CONNECTION_ID => { + let s = dv!(dec); + let retire_prior = dv!(dec); + let cid = d!(dec.decode_vec(1)).to_vec(); // TODO(mt) unnecessary copy + let srt = d!(dec.decode(16)); + let mut srtv: [u8; 16] = [0; 16]; + srtv.copy_from_slice(&srt); + + Ok(Frame::NewConnectionId { + sequence_number: s, + retire_prior, + connection_id: cid, + stateless_reset_token: srtv, + }) + } + FRAME_TYPE_RETIRE_CONNECTION_ID => Ok(Frame::RetireConnectionId { + sequence_number: dv!(dec), + }), + FRAME_TYPE_PATH_CHALLENGE => { + let data = d!(dec.decode(8)); + let mut datav: [u8; 8] = [0; 8]; + datav.copy_from_slice(&data); + Ok(Frame::PathChallenge { data: datav }) + } + FRAME_TYPE_PATH_RESPONSE => { + let data = d!(dec.decode(8)); + let mut datav: [u8; 8] = [0; 8]; + datav.copy_from_slice(&data); + Ok(Frame::PathResponse { data: datav }) + } + FRAME_TYPE_CONNECTION_CLOSE_TRANSPORT | FRAME_TYPE_CONNECTION_CLOSE_APPLICATION => { + Ok(Frame::ConnectionClose { + error_code: CloseError::from_type_bit(t, d!(dec.decode_varint())), + frame_type: dv!(dec), + reason_phrase: d!(dec.decode_vvec()).to_vec(), // TODO(mt) unnecessary copy + }) + } + _ => Err(Error::UnknownFrameType), } } @@ -711,7 +710,7 @@ mod tests { f.marshal(&mut d); assert_eq!(d, Encoder::from_hex(s)); - let f2 = Frame::decode(&mut d.as_decoder()).unwrap(); + let f2 = decode_frame(&mut d.as_decoder()).unwrap(); assert_eq!(*f, f2); } @@ -743,12 +742,12 @@ mod tests { // Try to parse ACK_ECN without ECN values let enc = Encoder::from_hex("035234523502523601020304"); let mut dec = enc.as_decoder(); - assert_eq!(Frame::decode(&mut dec).unwrap_err(), Error::NoMoreData); + assert_eq!(decode_frame(&mut dec).unwrap_err(), Error::NoMoreData); // Try to parse ACK_ECN without ECN values let enc = Encoder::from_hex("035234523502523601020304010203"); let mut dec = enc.as_decoder(); - assert_eq!(Frame::decode(&mut dec).unwrap(), f); + assert_eq!(decode_frame(&mut dec).unwrap(), f); } #[test] @@ -992,7 +991,7 @@ mod tests { ack_frame.marshal(&mut enc); println!("Encoded ACK={}", hex(&enc[..])); - let f = Frame::decode(&mut enc.as_decoder()).unwrap(); + let f = decode_frame(&mut enc.as_decoder()).unwrap(); if let Frame::Ack { largest_acknowledged, ack_delay, diff --git a/third_party/rust/neqo-transport/src/lib.rs b/third_party/rust/neqo-transport/src/lib.rs index ecf50a03b5af..3ada8da0b56a 100644 --- a/third_party/rust/neqo-transport/src/lib.rs +++ b/third_party/rust/neqo-transport/src/lib.rs @@ -5,13 +5,10 @@ // except according to those terms. #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::use_self)] use neqo_common::qinfo; use neqo_crypto; -mod cc; -mod cid; mod connection; mod crypto; mod dump; @@ -28,18 +25,16 @@ mod stream_id; mod tparams; mod tracking; -pub use self::cid::ConnectionIdManager; -pub use self::connection::{Connection, FixedConnectionIdManager, Output, Role, State}; +pub use self::connection::{ + Connection, ConnectionIdManager, FixedConnectionIdManager, Output, Role, State, +}; pub use self::events::{ConnectionEvent, ConnectionEvents}; pub use self::frame::CloseError; pub use self::frame::StreamType; pub use self::tparams::{tp_constants, TransportParameter}; /// The supported version of the QUIC protocol. -pub type Version = u32; -pub const QUIC_VERSION: Version = 0xff00_0000 + 25; - -const LOCAL_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); // 1 minute +pub const QUIC_VERSION: u32 = 0xff00_0018; type TransportError = u64; @@ -73,38 +68,31 @@ pub enum Error { InvalidResumptionToken, InvalidRetry, InvalidStreamId, - // Packet protection keys aren't available yet, or they have been discarded. KeysNotFound, - // An attempt to update keys can be blocked if - // a packet sent with the current keys hasn't been acknowledged. - KeyUpdateBlocked, NoMoreData, - NotConnected, - PacketNumberOverlap, PeerError(TransportError), TooMuchData, UnexpectedMessage, UnknownFrameType, VersionNegotiation, WrongRole, - KeysDiscarded, } impl Error { pub fn code(&self) -> TransportError { match self { - Self::NoError => 0, - Self::ServerBusy => 2, - Self::FlowControlError => 3, - Self::StreamLimitError => 4, - Self::StreamStateError => 5, - Self::FinalSizeError => 6, - Self::FrameEncodingError => 7, - Self::TransportParameterError => 8, - Self::ProtocolViolation => 10, - Self::InvalidMigration => 12, - Self::CryptoAlert(a) => 0x100 + u64::from(*a), - Self::PeerError(a) => *a, + Error::NoError => 0, + Error::ServerBusy => 2, + Error::FlowControlError => 3, + Error::StreamLimitError => 4, + Error::StreamStateError => 5, + Error::FinalSizeError => 6, + Error::FrameEncodingError => 7, + Error::TransportParameterError => 8, + Error::ProtocolViolation => 10, + Error::InvalidMigration => 12, + Error::CryptoAlert(a) => 0x100 + u64::from(*a), + Error::PeerError(a) => *a, // All the rest are internal errors. _ => 1, } @@ -114,20 +102,20 @@ impl Error { impl From for Error { fn from(err: neqo_crypto::Error) -> Self { qinfo!("Crypto operation failed {:?}", err); - Self::CryptoError(err) + Error::CryptoError(err) } } impl From for Error { fn from(_: std::num::TryFromIntError) -> Self { - Self::IntegerOverflow + Error::IntegerOverflow } } impl ::std::error::Error for Error { fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> { match self { - Self::CryptoError(e) => Some(e), + Error::CryptoError(e) => Some(e), _ => None, } } @@ -150,7 +138,7 @@ pub enum ConnectionError { impl ConnectionError { pub fn app_code(&self) -> Option { match self { - Self::Application(e) => Some(*e), + ConnectionError::Application(e) => Some(*e), _ => None, } } @@ -159,8 +147,8 @@ impl ConnectionError { impl From for ConnectionError { fn from(err: CloseError) -> Self { match err { - CloseError::Transport(c) => Self::Transport(Error::PeerError(c)), - CloseError::Application(c) => Self::Application(c), + CloseError::Transport(c) => ConnectionError::Transport(Error::PeerError(c)), + CloseError::Application(c) => ConnectionError::Application(c), } } } diff --git a/third_party/rust/neqo-transport/src/packet.rs b/third_party/rust/neqo-transport/src/packet.rs index dee780410943..d2609ea1efcf 100644 --- a/third_party/rust/neqo-transport/src/packet.rs +++ b/third_party/rust/neqo-transport/src/packet.rs @@ -5,20 +5,19 @@ // except according to those terms. // Encoding and decoding packets off the wire. -use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef}; -use crate::crypto::{CryptoDxState, CryptoStates}; -use crate::tracking::PNSpace; -use crate::{Error, Res, Version, QUIC_VERSION}; -use neqo_common::{hex, qerror, qtrace, Decoder, Encoder}; -use neqo_crypto::{aead::Aead, hkdf, random, TLS_AES_128_GCM_SHA256, TLS_VERSION_1_3}; +// A lot of methods and types contain the word Packet +#![allow(clippy::module_name_repetitions)] -use std::cell::RefCell; -use std::convert::TryFrom; -use std::fmt; -use std::iter::ExactSizeIterator; -use std::ops::{Deref, DerefMut, Range}; -use std::time::Instant; +use rand::Rng; + +use neqo_common::{hex, matches, qtrace, Decoder, Encoder}; +use neqo_crypto::aead::Aead; +use neqo_crypto::Epoch; + +use std::convert::{TryFrom, TryInto}; + +use crate::{Error, Res}; const PACKET_TYPE_INITIAL: u8 = 0x0; const PACKET_TYPE_0RTT: u8 = 0x01; @@ -27,491 +26,198 @@ const PACKET_TYPE_RETRY: u8 = 0x03; const PACKET_BIT_LONG: u8 = 0x80; const PACKET_BIT_SHORT: u8 = 0x00; -const PACKET_BIT_KEY_PHASE: u8 = 0x04; const PACKET_BIT_FIXED_QUIC: u8 = 0x40; -const PACKET_HP_MASK_LONG: u8 = 0x0f; -const PACKET_HP_MASK_SHORT: u8 = 0x1f; - const SAMPLE_SIZE: usize = 16; -const SAMPLE_OFFSET: usize = 4; -pub type PacketNumber = u64; +const AUTH_TAG_LEN: usize = 16; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, PartialEq)] pub enum PacketType { - VersionNegotiation, - Initial, - Handshake, - ZeroRtt, - Retry, Short, - OtherVersion, + ZeroRTT, + Handshake, + VN(Vec), // List of versions + Initial(Vec), // Token + Retry { odcid: ConnectionId, token: Vec }, +} + +impl Default for PacketType { + fn default() -> Self { + PacketType::Short + } } impl PacketType { - #[must_use] - fn code(self) -> u8 { + fn code(&self) -> u8 { match self { - Self::Initial => PACKET_TYPE_INITIAL, - Self::ZeroRtt => PACKET_TYPE_0RTT, - Self::Handshake => PACKET_TYPE_HANDSHAKE, - Self::Retry => PACKET_TYPE_RETRY, + PacketType::Initial(..) => PACKET_TYPE_INITIAL, + PacketType::ZeroRTT => PACKET_TYPE_0RTT, + PacketType::Handshake => PACKET_TYPE_HANDSHAKE, + PacketType::Retry { .. } => PACKET_TYPE_RETRY, _ => panic!("shouldn't be here"), } } } -/// The AEAD used for Retry is fixed, so use this. -fn make_retry_aead() -> Aead { - #[cfg(debug_assertions)] - ::neqo_crypto::assert_initialized(); +pub type Version = u32; +pub type PacketNumber = u64; - let secret = hkdf::import_key( - TLS_VERSION_1_3, - TLS_AES_128_GCM_SHA256, - &[ - 0x65, 0x6e, 0x61, 0xe3, 0x36, 0xae, 0x94, 0x17, 0xf7, 0xf0, 0xed, 0xd8, 0xd7, 0x8d, - 0x46, 0x1e, 0x2a, 0xa7, 0x08, 0x4a, 0xba, 0x7a, 0x14, 0xc1, 0xe9, 0xf7, 0x26, 0xd5, - 0x57, 0x09, 0x16, 0x9a, - ], - ) - .unwrap(); - Aead::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256, &secret, "quic ").unwrap() -} -thread_local!(static RETRY_AEAD: RefCell = RefCell::new(make_retry_aead())); -fn retry_expansion() -> usize { - if let Ok(ex) = RETRY_AEAD.try_with(|aead| aead.borrow().expansion()) { - ex - } else { - panic!("Unable to access Retry AEAD") - } -} +#[derive(Clone, Default, Eq, Hash, PartialEq)] +pub struct ConnectionId(pub Vec); -struct PacketBuilderoffsets { - /// The bits of the first octet that need masking. - first_byte_mask: u8, - /// The offset of the length field. - len: usize, - /// The location of the packet number field. - pn: Range, -} - -/// A packet builder that can be used to produce short packets and long packets. -/// This does not produce Retry or Version Negotiation. -pub struct PacketBuilder { - encoder: Encoder, - pn: PacketNumber, - header: Range, - offsets: PacketBuilderoffsets, -} - -impl PacketBuilder { - /// Start building a short header packet. - pub fn short(mut encoder: Encoder, key_phase: bool, dcid: &ConnectionId) -> Self { - let header_start = encoder.len(); - // TODO(mt) randomize the spin bit - encoder.encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2)); - encoder.encode(&dcid); - Self { - encoder, - pn: u64::max_value(), - header: header_start..header_start, - offsets: PacketBuilderoffsets { - first_byte_mask: PACKET_HP_MASK_SHORT, - pn: 0..0, - len: 0, - }, - } - } - - /// Start building a long header packet. - /// For an Initial packet you will need to call initial_token(), - /// even if the token is empty. - pub fn long( - mut encoder: Encoder, - pt: PacketType, - dcid: &ConnectionId, - scid: &ConnectionId, - ) -> Self { - let header_start = encoder.len(); - encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.code() << 4); - encoder.encode_uint(4, QUIC_VERSION); - encoder.encode_vec(1, dcid); - encoder.encode_vec(1, scid); - Self { - encoder, - pn: u64::max_value(), - header: header_start..header_start, - offsets: PacketBuilderoffsets { - first_byte_mask: PACKET_HP_MASK_LONG, - pn: 0..0, - len: 0, - }, - } - } - - /// For an Initial packet, encode the token. - /// If you fail to do this, then you will not get a valid packet. - pub fn initial_token(&mut self, token: &[u8]) { - debug_assert_eq!( - self.encoder[self.header.start] & 0xb0, - PACKET_BIT_LONG | PACKET_TYPE_INITIAL << 4 - ); - self.encoder.encode_vvec(token); - } - - /// Add a packet number of the given size. - /// For a long header packet, this also inserts a dummy length. - /// The length is filled in after calling `build`. - pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) { - // Reserve space for a length in long headers. - if (self.encoder[self.header.start] & 0x80) == PACKET_BIT_LONG { - self.offsets.len = self.encoder.len(); - self.encoder.encode(&[0; 2]); - } - // Encode the packet number and save its offset. - debug_assert!(pn_len <= 4 && pn_len > 0); - let pn_offset = self.encoder.len(); - self.encoder.encode_uint(pn_len, pn); - self.offsets.pn = pn_offset..self.encoder.len(); - - // Now encode the packet number length and save the header length. - self.encoder[self.header.start] |= (pn_len - 1) as u8; - self.header.end = self.encoder.len(); - self.pn = pn; - } - - fn write_len(&mut self, expansion: usize) { - let len = self.encoder.len() - (self.offsets.len + 2) + expansion; - self.encoder[self.offsets.len] = 0x40 | ((len >> 8) & 0x3f) as u8; - self.encoder[self.offsets.len + 1] = (len & 0xff) as u8; - } - - /// Build the packet and return the encoder. - pub fn build(mut self, crypto: &mut CryptoDxState) -> Res { - if self.offsets.len > 0 { - self.write_len(crypto.expansion()); - } - let hdr = &self.encoder[self.header.clone()]; - let body = &self.encoder[self.header.end..]; - qtrace!( - "Packet build pn={} hdr={} body={}", - self.pn, - hex(hdr), - hex(body) - ); - let ciphertext = crypto.encrypt(self.pn, hdr, body)?; - - // Calculate the mask. - let offset = SAMPLE_OFFSET - self.offsets.pn.len(); - assert!(offset + SAMPLE_SIZE <= ciphertext.len()); - let sample = &ciphertext[offset..offset + SAMPLE_SIZE]; - let mask = crypto.compute_mask(sample)?; - - // Apply the mask. - self.encoder[self.header.start] ^= mask[0] & self.offsets.first_byte_mask; - for (i, j) in (1..=self.offsets.pn.len()).zip(self.offsets.pn) { - self.encoder[j] ^= mask[i]; - } - - // Finally, cut off the plaintext and add back the ciphertext. - self.encoder.truncate(self.header.end); - self.encoder.encode(&ciphertext); - qtrace!("Packet built {}", hex(&self.encoder)); - Ok(self.encoder) - } - - /// Abort writing of this packet and return the encoder. - #[must_use] - pub fn abort(mut self) -> Encoder { - self.encoder.truncate(self.header.start); - self.encoder - } - - /// Work out if nothing was added after the header. - #[must_use] - pub fn is_empty(&self) -> bool { - self.encoder.len() == self.header.end - } - - /// Make a retry packet. - /// As this is a simple packet, this is just an associated function. - /// As Retry is odd (it has to be constructed with leading bytes), - /// this returns a Vec rather than building on an encoder. - pub fn retry(dcid: &[u8], scid: &[u8], token: &[u8], odcid: &[u8]) -> Res> { - let mut encoder = Encoder::default(); - encoder.encode_vec(1, odcid); - let start = encoder.len(); - encoder.encode_byte( - PACKET_BIT_LONG - | PACKET_BIT_FIXED_QUIC - | (PACKET_TYPE_RETRY << 4) - | (random(1)[0] & 0xf), - ); - encoder.encode_uint(4, QUIC_VERSION); - encoder.encode_vec(1, dcid); - encoder.encode_vec(1, scid); - debug_assert_ne!(token.len(), 0); - encoder.encode(token); - let tag = RETRY_AEAD - .try_with(|aead| -> Res> { - let mut buf = vec![0; aead.borrow().expansion()]; - Ok(aead.borrow().encrypt(0, &encoder, &[], &mut buf)?.to_vec()) - }) - .map_err(|e| { - qerror!("Unable to access Retry AEAD: {:?}", e); - Error::InternalError - })??; - encoder.encode(&tag); - let mut complete: Vec = encoder.into(); - Ok(complete.split_off(start)) - } - - /// Make a Version Negotiation packet. - pub fn version_negotiation(dcid: &[u8], scid: &[u8]) -> Vec { - let mut encoder = Encoder::default(); - let mut grease = random(5); - // This will not include the "QUIC bit" sometimes. Intentionally. - encoder.encode_byte(PACKET_BIT_LONG | (grease[4] & 0x7f)); - encoder.encode(&[0; 4]); // Zero version == VN. - encoder.encode_vec(1, dcid); - encoder.encode_vec(1, scid); - encoder.encode_uint(4, QUIC_VERSION); - // Add a greased version, using the randomness already generated. - for g in &mut grease[..4] { - *g = *g & 0xf0 | 0x0a; - } - encoder.encode(&grease[0..4]); - encoder.into() - } -} - -impl Deref for PacketBuilder { - type Target = Encoder; +impl std::ops::Deref for ConnectionId { + type Target = [u8]; fn deref(&self) -> &Self::Target { - &self.encoder + &self.0 } } -impl DerefMut for PacketBuilder { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.encoder +impl ConnectionId { + pub fn generate(len: usize) -> Self { + assert!(matches!(len, 0..=20)); + let mut v = vec![0; len]; + rand::thread_rng().fill(&mut v[..]); + Self(v) + } + + // Apply a wee bit of greasing here in picking a length between 8 and 20 bytes long. + pub fn generate_initial() -> ConnectionId { + let mut v = [0u8; 1]; + rand::thread_rng().fill(&mut v[..]); + // Bias selection toward picking 8 (>50% of the time). + let len: usize = ::std::cmp::max(8, 5 + (v[0] & (v[0] >> 4))).into(); + ConnectionId::generate(len) } } -impl Into for PacketBuilder { - fn into(self) -> Encoder { - self.encoder +impl ::std::fmt::Debug for ConnectionId { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "CID {}", hex(&self.0)) } } -/// PublicPacket holds information from packets that is public only. This allows for -/// processing of packets prior to decryption. -pub struct PublicPacket<'a> { - /// The packet type. - packet_type: PacketType, - /// The recovered destination connection ID. - dcid: ConnectionIdRef<'a>, - /// The source connection ID, if this is a long header packet. - scid: Option>, - /// Any token that is included in the packet (Retry always has a token; Initial sometimes does). - /// This is empty when there is no token. - token: &'a [u8], - /// The size of the header, not including the packet number. - header_len: usize, - /// A reference to the entire packet, including the header. - data: &'a [u8], +impl ::std::fmt::Display for ConnectionId { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "{}", hex(&self.0)) + } } -impl<'a> PublicPacket<'a> { - fn opt(v: Option) -> Res { - if let Some(v) = v { - Ok(v) - } else { - Err(Error::NoMoreData) +impl From<&[u8]> for ConnectionId { + fn from(buf: &[u8]) -> ConnectionId { + ConnectionId(Vec::from(buf)) + } +} + +pub trait ConnectionIdDecoder { + fn decode_cid(&self, dec: &mut Decoder) -> Option; +} + +#[derive(Default, Debug)] +#[allow(clippy::module_name_repetitions)] +pub struct PacketHdr { + pub tbyte: u8, + pub tipe: PacketType, + pub version: Option, + pub dcid: ConnectionId, + pub scid: Option, + pub pn: PacketNumber, + pub epoch: Epoch, + pub hdr_len: usize, + body_len: usize, +} + +impl PacketHdr { + // Similar names are allowed here because + // dcid and scid are defined and commonly used in the spec. + #[allow(clippy::similar_names)] + pub fn new( + tbyte: u8, + tipe: PacketType, + version: Option, + dcid: ConnectionId, + scid: Option, + pn: PacketNumber, + epoch: Epoch, + ) -> Self { + Self { + tbyte, + tipe, + version, + dcid, + scid, + pn, + epoch, + hdr_len: 0, + body_len: 0, } } - /// Decode the type-specific portions of a long header. - /// This includes reading the length and the remainder of the packet. - /// Returns a tuple of any token and the length of the header. - fn decode_type_specific( - decoder: &mut Decoder<'a>, - packet_type: PacketType, - ) -> Res<(&'a [u8], usize)> { - if packet_type == PacketType::Retry { - let header_len = decoder.offset(); - let expansion = retry_expansion(); - let token = Self::opt(decoder.decode(decoder.remaining() - expansion))?; - if token.is_empty() { - return Err(Error::InvalidPacket); + pub fn body_len(&self) -> usize { + self.body_len + } + + // header length plus auth tag + pub fn overhead(&self, aead: &Aead, pmtu: usize) -> usize { + match &self.tipe { + PacketType::Short => { + // Leading byte. + let mut len = 1; + len += self.dcid.0.len(); + len += pn_length(self.pn); + len + aead.expansion() } - Self::opt(decoder.decode(expansion))?; - return Ok((token, header_len)); - } - let token = if packet_type == PacketType::Initial { - Self::opt(decoder.decode_vvec())? - } else { - &[] - }; - let len = Self::opt(decoder.decode_varint())?; - let header_len = decoder.offset(); - let _body = Self::opt(decoder.decode(usize::try_from(len)?))?; - Ok((token, header_len)) - } + PacketType::VN(_) => unimplemented!("Can't get overhead for VN"), + PacketType::Retry { .. } => unimplemented!("Can't get overhead for Retry"), + PacketType::Initial(..) | PacketType::ZeroRTT | PacketType::Handshake => { + let pnl = pn_length(self.pn); - /// Decode the common parts of a packet. This provides minimal parsing and validation. - /// Returns a tuple of a `PublicPacket` and a slice with any remainder from the datagram. - pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> { - let mut decoder = Decoder::new(data); - let first = Self::opt(decoder.decode_byte())?; + // Leading byte. + let mut len = 1; + len += 4; // Version + len += 1; // DCID length + len += self.dcid.len(); + len += 1; // SCID length + len += self.scid.as_ref().unwrap().len(); - if first & 0x80 == PACKET_BIT_SHORT { - return if first & 0x40 == PACKET_BIT_FIXED_QUIC { - let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?; - if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE { - return Err(Error::InvalidPacket); + if let PacketType::Initial(token) = &self.tipe { + len += Encoder::varint_len(token.len().try_into().unwrap()); + len += token.len(); } - let header_len = decoder.offset(); - Ok(( - Self { - packet_type: PacketType::Short, - dcid, - scid: None, - token: &[], - header_len, - data, - }, - &[], - )) - } else { - Err(Error::InvalidPacket) - }; + + len += Encoder::varint_len((pnl + pmtu + aead.expansion()) as u64); + len += pnl; + len + aead.expansion() + } } + } +} - // Generic long header. - let version = Version::try_from(Self::opt(decoder.decode_uint(4))?).unwrap(); - let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); - let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?); +pub trait CryptoCtx { + fn compute_mask(&self, sample: &[u8]) -> Res>; + fn aead_decrypt(&self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res>; + fn aead_encrypt(&self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res>; +} - // Version negotiation. - if version == 0 { - return Ok(( - Self { - packet_type: PacketType::VersionNegotiation, - dcid, - scid: Some(scid), - token: &[], - header_len: decoder.offset(), - data, - }, - &[], - )); +pub struct PacketNumberDecoder { + expected: u64, +} +impl PacketNumberDecoder { + pub fn new(largest_acknowledged: Option) -> Self { + Self { + expected: largest_acknowledged.map_or(0, |x| x + 1), } - - // Check that this is a long header from this version. - if version != QUIC_VERSION { - return Ok(( - Self { - packet_type: PacketType::OtherVersion, - dcid, - scid: Some(scid), - token: &[], - header_len: decoder.offset(), - data, - }, - &[], - )); - } - if (first & PACKET_BIT_FIXED_QUIC) != PACKET_BIT_FIXED_QUIC { - return Err(Error::InvalidPacket); - } - let packet_type = match (first >> 4) & 3 { - PACKET_TYPE_INITIAL => PacketType::Initial, - PACKET_TYPE_0RTT => PacketType::ZeroRtt, - PACKET_TYPE_HANDSHAKE => PacketType::Handshake, - PACKET_TYPE_RETRY => PacketType::Retry, - _ => unreachable!(), - }; - - // The type-specific code includes a token. This consumes the remainder of the packet. - let (token, header_len) = Self::decode_type_specific(&mut decoder, packet_type)?; - let end = data.len() - decoder.remaining(); - let (data, remainder) = data.split_at(end); - Ok(( - Self { - packet_type, - dcid, - scid: Some(scid), - token, - header_len, - data, - }, - remainder, - )) } - /// Validate the given packet as though it were a retry. - pub fn is_valid_retry(&self, odcid: &ConnectionId) -> bool { - if self.packet_type != PacketType::Retry { - return false; - } - let expansion = retry_expansion(); - if self.data.len() <= expansion { - return false; - } - let (header, tag) = self.data.split_at(self.data.len() - expansion); - let mut encoder = Encoder::with_capacity(self.data.len()); - encoder.encode_vec(1, odcid); - encoder.encode(header); - RETRY_AEAD - .try_with(|aead| -> bool { - let mut buf = vec![0; expansion]; - if let Ok(v) = aead.borrow().decrypt(0, &encoder, tag, &mut buf) { - v.is_empty() - } else { - false - } - }) - .unwrap_or_else(|e| { - qerror!("Unable to access Retry AEAD: {:?}", e); - false - }) - } - - pub fn is_valid_initial(&self) -> bool { - // Packet has to be an initial, with a DCID of 8 bytes, or a token. - // Assume that the Server class validates the token. - self.packet_type == PacketType::Initial - && (self.dcid().len() >= 8 || !self.token.is_empty()) - } - - pub fn packet_type(&self) -> PacketType { - self.packet_type - } - - pub fn dcid(&self) -> &ConnectionIdRef<'a> { - &self.dcid - } - - pub fn scid(&self) -> &ConnectionIdRef<'a> { - self.scid - .as_ref() - .expect("should only be called for long header packets") - } - - pub fn token(&self) -> &'a [u8] { - self.token - } - - fn decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber { + // TODO(mt) test this. It's a strict implementation of the spec, + // but that doesn't mean we shouldn't test it. + fn decode_pn(&self, pn: u64, w: usize) -> PacketNumber { let window = 1_u64 << (w * 8); - let candidate = (expected & !(window - 1)) | pn; - if candidate + (window / 2) <= expected { + let candidate = (self.expected & !(window - 1)) | pn; + if candidate + (window / 2) <= self.expected { candidate + window - } else if candidate > expected + (window / 2) { + } else if candidate > self.expected + (window / 2) { match candidate.checked_sub(window) { Some(pn_sub) => pn_sub, None => candidate, @@ -520,438 +226,523 @@ impl<'a> PublicPacket<'a> { candidate } } +} - /// Decrypt the header of the packet. - fn decrypt_header( - &self, - crypto: &mut CryptoDxState, - ) -> Res<(bool, PacketNumber, Vec, &'a [u8])> { - assert_ne!(self.packet_type, PacketType::Retry); - assert_ne!(self.packet_type, PacketType::VersionNegotiation); +fn encode_pnl(l: usize) -> u8 { + assert!(l <= 4); + (l - 1) as u8 +} - qtrace!( - "unmask hdr={}", - hex(&self.data[..self.header_len + SAMPLE_OFFSET]) - ); +fn decode_pnl(u: u8) -> usize { + assert!(u < 4); // This came from 2 bits + (u + 1) as usize +} - let sample_offset = self.header_len + SAMPLE_OFFSET; - let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE)) - { - crypto.compute_mask(sample) - } else { - Err(Error::NoMoreData) - }?; +/* + Short Header - // Un-mask the leading byte. - let bits = if self.packet_type == PacketType::Short { - PACKET_HP_MASK_SHORT - } else { - PACKET_HP_MASK_LONG - }; - let first_byte = self.data[0] ^ (mask[0] & bits); - let pn_len = usize::from((first_byte & 0x3) + 1); + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+ + |0|1|S|R|R|K|P P| + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Destination Connection ID (0..144) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Packet Number (8/16/24/32) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Protected Payload (*) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // Make a copy of the header to work on. - let mut hdrbytes = self.data[..self.header_len + pn_len].to_vec(); - hdrbytes[0] = first_byte; - // Unmask the PN. - let mut pn_encoded: u64 = 0; - for i in 0..pn_len { - hdrbytes[self.header_len + i] ^= mask[1 + i]; - pn_encoded <<= 8; - pn_encoded += u64::from(hdrbytes[self.header_len + i]); - } + Long Header - qtrace!("unmasked hdr={}", hex(&hdrbytes)); + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+ + |1|1|T T|X X X X| + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Version (32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | DCID Len (8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Destination Connection ID (0..160) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SCID Len (8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Source Connection ID (0..160) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - let key_phase = self.packet_type == PacketType::Short - && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE; - let pn = Self::decode_pn(crypto.next_pn(), pn_encoded, pn_len); - Ok(( - key_phase, - pn, - hdrbytes, - &self.data[self.header_len + pn_len..], - )) - } + Handshake + +-+-+-+-+-+-+-+-+ + |1|1| 2 |R R|P P| + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Version (32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | DCID Len (8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Destination Connection ID (0..160) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SCID Len (8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Source Connection ID (0..160) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Length (i) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Packet Number (8/16/24/32) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Payload (*) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res { - let space = PNSpace::from(self.packet_type); - // This has to work in two stages because we need to remove header protection - // before picking the keys to use. - if let Some(rx) = crypto.rx_hp(space) { - // Note that this will dump early, which creates a side-channel. - // This is OK in this case because we the only reason this can - // fail is if the cryptographic module is bad or the packet is - // too small (which is public information). - let (key_phase, pn, header, body) = self.decrypt_header(rx)?; - qtrace!([rx], "decoded header: {:?}", header); - if let Some(rx) = crypto.rx(space, key_phase) { - let d = rx.decrypt(pn, &header, body)?; - // If this is the first packet ever successfully decrypted - // using `rx`, make sure to initiate a key update. - if rx.needs_update() { - crypto.key_update_received(release_at)?; - } - crypto.check_pn_overlap()?; - Ok(DecryptedPacket { - pt: self.packet_type, - pn, - data: d, - }) - } else { - Err(Error::DecryptError) + Retry + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+ + |1|1| 3 | Unused| + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Version (32) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | DCID Len (8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Destination Connection ID (0..160) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SCID Len (8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Source Connection ID (0..160) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | ODCID Len (8) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Original Destination Connection ID (0..160) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Retry Token (*) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ + +pub fn decode_packet_hdr(cid_parser: &dyn ConnectionIdDecoder, pd: &[u8]) -> Res { + macro_rules! d { + ($d:expr) => { + match $d { + Some(v) => v, + _ => return Err(Error::NoMoreData), } - } else { - Err(Error::DecryptError) + }; + } + + let mut p = PacketHdr::default(); + let mut d = Decoder::from(pd); + + // Get the type byte + p.tbyte = d!(d.decode_byte()); + if (p.tbyte & 0x80) == 0 { + if p.tbyte & 0x40 == 0 { + return Err(Error::InvalidPacket); } + + // Short Header. + p.tipe = PacketType::Short; + let cid = d!(cid_parser.decode_cid(&mut d)); + p.dcid = ConnectionId(cid.to_vec()); // TODO(mt) unnecessary copy + p.hdr_len = pd.len() - d.remaining(); + p.body_len = d.remaining(); + p.epoch = 3; // TODO(ekr@rtfm.com): Decode key phase bits. + return Ok(p); + } + + let version = d!(d.decode_uint(4)) as u32; + p.version = Some(version); + p.dcid = ConnectionId(d!(d.decode_vec(1)).to_vec()); + p.scid = Some(ConnectionId(d!(d.decode_vec(1)).to_vec())); + + if version == 0 { + let mut vns = vec![]; + while d.remaining() > 0 { + vns.push(d!(d.decode_uint(4)) as u32); + } + p.tipe = PacketType::VN(vns); + // No need to set hdr_length and body_length + // because we won't need them. + return Ok(p); + } else { + if p.tbyte & 0x40 == 0 { + return Err(Error::InvalidPacket); + } + + p.tipe = match (p.tbyte >> 4) & 0x3 { + // TODO(ekr@rtfm.com): Check the 0 bits. + PACKET_TYPE_INITIAL => { + p.epoch = 0; + PacketType::Initial(d!(d.decode_vvec()).to_vec()) // TODO(mt) unnecessary copy + } + PACKET_TYPE_0RTT => { + p.epoch = 1; + PacketType::ZeroRTT + } + PACKET_TYPE_HANDSHAKE => { + p.epoch = 2; + PacketType::Handshake + } + PACKET_TYPE_RETRY => { + let odcid = ConnectionId(d!(d.decode_vec(1)).to_vec()); // TODO(mt) unnecessary copy + let token = d.decode_remainder().to_vec(); // TODO(mt) unnecessary copy + p.tipe = PacketType::Retry { odcid, token }; + return Ok(p); + } + _ => unreachable!(), + }; + } + + p.body_len = usize::try_from(d!(d.decode_varint()))?; + if p.body_len > d.remaining() { + return Err(Error::InvalidPacket); + } + p.hdr_len = pd.len() - d.remaining(); + + Ok(p) +} + +pub fn decrypt_packet( + crypto: &dyn CryptoCtx, + pn: PacketNumberDecoder, + hdr: &mut PacketHdr, + pkt: &[u8], +) -> Res> { + assert!(!matches!( + hdr.tipe, + PacketType::Retry{..} | PacketType::VN(_) + )); + + // First remove the header protection. + let payload = &pkt[hdr.hdr_len..]; + + if payload.len() < (4 + SAMPLE_SIZE) { + return Err(Error::NoMoreData); + } + let mask = crypto.compute_mask(&payload[4..(SAMPLE_SIZE + 4)])?; + + // Now put together a raw header to work on. + let pn_len = decode_pnl((hdr.tbyte ^ mask[0]) & 0x3); + let mut hdrbytes = pkt[0..(hdr.hdr_len + pn_len)].to_vec(); + + qtrace!("unmask hdr={}", hex(&hdrbytes)); + // Un-mask the leading byte. + hdrbytes[0] ^= mask[0] + & match hdr.tipe { + PacketType::Short => 0x1f, + _ => 0x0f, + }; + + // Now unmask the PN. + let mut pn_encoded: u64 = 0; + for i in 0..pn_len { + hdrbytes[hdr.hdr_len + i] ^= mask[1 + i]; + pn_encoded <<= 8; + pn_encoded += u64::from(hdrbytes[hdr.hdr_len + i]); + } + qtrace!("unmasked hdr={}", hex(&hdrbytes)); + hdr.hdr_len += pn_len; + hdr.body_len -= pn_len; + + // Now call out to expand the PN. + hdr.pn = pn.decode_pn(pn_encoded, pn_len); + + // Finally, decrypt. + Ok(crypto.aead_decrypt( + hdr.pn, + &hdrbytes, + &pkt[hdr.hdr_len..hdr.hdr_len + hdr.body_len()], + )?) +} + +fn encode_packet_short(crypto: &dyn CryptoCtx, hdr: &PacketHdr, body: &[u8]) -> Vec { + let mut enc = Encoder::default(); + // Leading byte. + let pnl = pn_length(hdr.pn); + enc.encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | encode_pnl(pnl)); + enc.encode(&hdr.dcid.0); + enc.encode_uint(pnl, hdr.pn); + + encrypt_packet(crypto, hdr, enc, body) +} + +pub fn encode_packet_vn(hdr: &PacketHdr) -> Vec { + let mut d = Encoder::default(); + let mut rand_byte: [u8; 1] = [0; 1]; + rand::thread_rng().fill(&mut rand_byte); + d.encode_byte(PACKET_BIT_LONG | rand_byte[0]); + d.encode_uint(4, 0_u64); // version + d.encode_vec(1, &hdr.dcid); + d.encode_vec(1, hdr.scid.as_ref().unwrap()); + if let PacketType::VN(vers) = &hdr.tipe { + for ver in vers { + d.encode_uint(4, *ver); + } + } else { + panic!("wrong packet type"); + } + d.into() +} + +/* Handle Initial, 0-RTT, Handshake. */ +fn encode_packet_long(crypto: &dyn CryptoCtx, hdr: &PacketHdr, body: &[u8]) -> Vec { + let mut enc = Encoder::default(); + + let pnl = pn_length(hdr.pn); + enc.encode_byte( + PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | hdr.tipe.code() << 4 | encode_pnl(pnl), + ); + enc.encode_uint(4, hdr.version.unwrap()); + enc.encode_vec(1, &*hdr.dcid); + enc.encode_vec(1, &*hdr.scid.as_ref().unwrap()); + + if let PacketType::Initial(token) = &hdr.tipe { + enc.encode_vvec(&token); + } + enc.encode_varint((pnl + body.len() + AUTH_TAG_LEN) as u64); + enc.encode_uint(pnl, hdr.pn); + + encrypt_packet(crypto, hdr, enc, body) +} + +fn encrypt_packet( + crypto: &dyn CryptoCtx, + hdr: &PacketHdr, + mut enc: Encoder, + body: &[u8], +) -> Vec { + let hdr_len = enc.len(); + // Encrypt the packet. This has too many copies. + let ct = crypto.aead_encrypt(hdr.pn, &enc, body).unwrap(); + enc.encode(&ct); + qtrace!("mask hdr={}", hex(&enc[0..hdr_len])); + let pn_start = hdr_len - pn_length(hdr.pn); + let mask = crypto + .compute_mask(&enc[pn_start + 4..pn_start + SAMPLE_SIZE + 4]) + .unwrap(); + enc[0] ^= mask[0] + & match hdr.tipe { + PacketType::Short => 0x1f, + _ => 0x0f, + }; + for i in 0..pn_length(hdr.pn) { + enc[pn_start + i] ^= mask[i + 1]; + } + qtrace!("masked hdr={}", hex(&enc[0..hdr_len])); + enc.into() +} + +// TODO(ekr@rtfm.com): Minimal packet number lengths. +fn pn_length(_pn: PacketNumber) -> usize { + 3 +} + +pub fn encode_retry(hdr: &PacketHdr) -> Vec { + let mut rand_byte: [u8; 1] = [0; 1]; + rand::thread_rng().fill(&mut rand_byte); + if let PacketType::Retry { odcid, token } = &hdr.tipe { + let mut enc = Encoder::default(); + let b0 = PACKET_BIT_LONG + | PACKET_BIT_FIXED_QUIC + | (PACKET_TYPE_RETRY << 4) + | (rand_byte[0] & 0xf); + enc.encode_byte(b0); + enc.encode_uint(4, hdr.version.unwrap()); + enc.encode_vec(1, &hdr.dcid); + enc.encode_vec(1, &hdr.scid.as_ref().unwrap()); + enc.encode_vec(1, odcid); + enc.encode(token); + enc.into() + } else { + unreachable!() } } -impl fmt::Debug for PublicPacket<'_> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{:?}: {} {}", - self.packet_type(), - hex(&self.data[..self.header_len]), - hex(&self.data[self.header_len..]) - ) - } -} - -pub struct DecryptedPacket { - pt: PacketType, - pn: PacketNumber, - data: Vec, -} - -impl DecryptedPacket { - pub fn packet_type(&self) -> PacketType { - self.pt - } - - pub fn pn(&self) -> PacketNumber { - self.pn - } -} - -impl Deref for DecryptedPacket { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - &self.data[..] +pub fn encode_packet(crypto: &dyn CryptoCtx, hdr: &PacketHdr, body: &[u8]) -> Vec { + match &hdr.tipe { + PacketType::Short => encode_packet_short(crypto, hdr, body), + PacketType::VN(_) => encode_packet_vn(hdr), + PacketType::Retry { .. } => encode_retry(hdr), + PacketType::Initial(..) | PacketType::ZeroRTT | PacketType::Handshake => { + encode_packet_long(crypto, hdr, body) + } } } #[cfg(test)] +#[allow(unused_variables)] mod tests { use super::*; - use crate::crypto::{CryptoDxState, CryptoStates}; - use crate::FixedConnectionIdManager; - use neqo_common::Encoder; - use test_fixture::{fixture_init, now}; + use neqo_common::matches; - const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; - const SERVER_CID: &[u8] = &[0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5]; + const TEST_BODY: [u8; 6] = [0x01, 0x23, 0x45, 0x67, 0x89, 0x10]; - /// This is a connection ID manager, which is only used for decoding short header packets. - fn cid_mgr() -> FixedConnectionIdManager { - FixedConnectionIdManager::new(SERVER_CID.len()) - } + struct TestFixture {} - const SAMPLE_INITIAL_PAYLOAD: &[u8] = &[ - 0x0d, 0x00, 0x00, 0x00, 0x00, 0x18, 0x41, 0x0a, 0x02, 0x00, 0x00, 0x56, 0x03, 0x03, 0xee, - 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78, 0x25, 0xdd, 0xf7, - 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43, 0x0b, 0x9a, 0x04, 0x5a, - 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, - 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0, 0x8a, 0x60, 0x99, 0x3c, 0x14, 0x4e, - 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83, 0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b, 0xb9, - 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04, - ]; - const SAMPLE_INITIAL: &[u8] = &[ - 0xc9, 0xff, 0x00, 0x00, 0x19, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, - 0x00, 0x40, 0x74, 0x16, 0x8b, 0xf2, 0x2b, 0x70, 0x02, 0x59, 0x6f, 0x99, 0xae, 0x67, 0xab, - 0xf6, 0x5a, 0x58, 0x52, 0xf5, 0x4f, 0x58, 0xc3, 0x7c, 0x80, 0x86, 0x82, 0xe2, 0xe4, 0x04, - 0x92, 0xd8, 0xa3, 0x89, 0x9f, 0xb0, 0x4f, 0xc0, 0xaf, 0xe9, 0xaa, 0xbc, 0x87, 0x67, 0xb1, - 0x8a, 0x0a, 0xa4, 0x93, 0x53, 0x74, 0x26, 0x37, 0x3b, 0x48, 0xd5, 0x02, 0x21, 0x4d, 0xd8, - 0x56, 0xd6, 0x3b, 0x78, 0xce, 0xe3, 0x7b, 0xc6, 0x64, 0xb3, 0xfe, 0x86, 0xd4, 0x87, 0xac, - 0x7a, 0x77, 0xc5, 0x30, 0x38, 0xa3, 0xcd, 0x32, 0xf0, 0xb5, 0x00, 0x4d, 0x9f, 0x57, 0x54, - 0xc4, 0xf7, 0xf2, 0xd1, 0xf3, 0x5c, 0xf3, 0xf7, 0x11, 0x63, 0x51, 0xc9, 0x2b, 0x99, 0xc8, - 0xae, 0x58, 0x33, 0x22, 0x5c, 0xb5, 0x18, 0x55, 0x20, 0xd6, 0x1e, 0x68, 0xcf, 0x5f, - ]; + const AEAD_MASK: u8 = 0; - #[test] - fn sample_server_initial() { - fixture_init(); - let mut prot = CryptoDxState::test_default(); - - // The spec uses PN=1, but our crypto refuses to skip packet numbers. - // So burn an encryption: - let burn = prot.encrypt(0, &[], &[]).expect("burn OK"); - assert_eq!(burn.len(), prot.expansion()); - - let mut builder = PacketBuilder::long( - Encoder::new(), - PacketType::Initial, - &ConnectionId::from(&[][..]), - &ConnectionId::from(SERVER_CID), - ); - builder.initial_token(&[]); - builder.pn(1, 2); - builder.encode(&SAMPLE_INITIAL_PAYLOAD); - let packet = builder.build(&mut prot).expect("build"); - assert_eq!(&packet[..], SAMPLE_INITIAL); - } - - #[test] - fn decrypt_initial() { - const EXTRA: &[u8] = &[0xce; 33]; - - fixture_init(); - let mut padded = SAMPLE_INITIAL.to_vec(); - padded.extend_from_slice(EXTRA); - let (packet, remainder) = PublicPacket::decode(&padded, &cid_mgr()).unwrap(); - assert_eq!(packet.packet_type(), PacketType::Initial); - assert_eq!(&packet.dcid()[..], &[]); - assert_eq!(&packet.scid()[..], SERVER_CID); - assert!(packet.token().is_empty()); - assert_eq!(remainder, EXTRA); - - let decrypted = packet - .decrypt(&mut CryptoStates::test_default(), now()) - .unwrap(); - assert_eq!(decrypted.pn(), 1); - } - - const SAMPLE_SHORT: &[u8] = &[ - 0x4c, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x55, 0x80, 0x33, 0xda, 0x1a, 0x01, - 0x19, 0x47, 0x57, 0xe2, 0x23, 0xcf, 0xe8, 0xde, 0x58, 0xce, 0x8b, 0xab, 0xc5, 0x19, - ]; - const SAMPLE_SHORT_PAYLOAD: &[u8] = &[0; 3]; - - #[test] - fn build_short() { - fixture_init(); - let mut builder = - PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID)); - builder.pn(0, 1); - builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling. - let packet = builder - .build(&mut CryptoDxState::test_default()) - .expect("build"); - assert_eq!(&packet[..], SAMPLE_SHORT); - } - - #[test] - fn decode_short() { - let (packet, remainder) = PublicPacket::decode(SAMPLE_SHORT, &cid_mgr()).unwrap(); - assert_eq!(packet.packet_type(), PacketType::Short); - assert!(remainder.is_empty()); - let decrypted = packet - .decrypt(&mut CryptoStates::test_default(), now()) - .unwrap(); - assert_eq!(&decrypted[..], SAMPLE_SHORT_PAYLOAD); - } - - /// By telling the decoder that the connection ID is shorter than it really is, we get a decryption error. - #[test] - fn decode_short_bad_cid() { - fixture_init(); - let (packet, remainder) = PublicPacket::decode( - SAMPLE_SHORT, - &FixedConnectionIdManager::new(SERVER_CID.len() - 1), - ) - .unwrap(); - assert_eq!(packet.packet_type(), PacketType::Short); - assert!(remainder.is_empty()); - assert!(packet - .decrypt(&mut CryptoStates::test_default(), now()) - .is_err()); - } - - /// Saying that the connection ID is longer causes the initial decode to fail. - #[test] - fn decode_short_long_cid() { - assert!(PublicPacket::decode( - SAMPLE_SHORT, - &FixedConnectionIdManager::new(SERVER_CID.len() + 1) - ) - .is_err()); - } - - #[test] - fn build_two() { - fixture_init(); - let mut prot = CryptoDxState::test_default(); - let mut builder = PacketBuilder::long( - Encoder::new(), - PacketType::Handshake, - &ConnectionId::from(SERVER_CID), - &ConnectionId::from(CLIENT_CID), - ); - builder.pn(0, 1); - builder.encode(&[0; 3]); - let encoder = builder.build(&mut prot).expect("build"); - assert_eq!(encoder.len(), 45); - let first = encoder.clone(); - - let mut builder = PacketBuilder::short(encoder, false, &ConnectionId::from(SERVER_CID)); - builder.pn(1, 3); - builder.encode(&[0]); // Minimal size (packet number is big enough). - let encoder = builder.build(&mut prot).expect("build"); - assert_eq!( - &first[..], - &encoder[..first.len()], - "the first packet should be a prefix" - ); - assert_eq!(encoder.len(), 45 + 29); - } - - #[test] - fn build_abort() { - let mut builder = PacketBuilder::long( - Encoder::new(), - PacketType::Initial, - &ConnectionId::from(&[][..]), - &ConnectionId::from(SERVER_CID), - ); - builder.initial_token(&[]); - builder.pn(1, 2); - let encoder = builder.abort(); - assert!(encoder.is_empty()); - } - - const SAMPLE_RETRY: &[u8] = &[ - 0xff, 0xff, 0x00, 0x00, 0x19, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, - 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x1e, 0x5e, 0xc5, 0xb0, 0x14, 0xcb, 0xb1, 0xf0, 0xfd, 0x93, - 0xdf, 0x40, 0x48, 0xc4, 0x46, 0xa6, - ]; - const RETRY_TOKEN: &[u8] = b"token"; - - #[test] - fn build_retry() { - fixture_init(); - let retry = PacketBuilder::retry(&[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap(); - - let (packet, remainder) = PublicPacket::decode(&retry, &cid_mgr()).unwrap(); - assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); - assert!(remainder.is_empty()); - - // The builder adds randomness, which makes expectations hard. - // So only do a full check when that randomness matches up. - if retry[0] == SAMPLE_RETRY[0] { - assert_eq!(&retry, &SAMPLE_RETRY); - } else { - // Otherwise, just check that the header is OK. - assert_eq!(retry[0] & 0xf0, 0xf0); - let header_range = 1..retry.len() - 16; - assert_eq!(&retry[header_range.clone()], &SAMPLE_RETRY[header_range]); + impl TestFixture { + fn auth_tag(hdr: &[u8], body: &[u8]) -> [u8; AUTH_TAG_LEN] { + [0; AUTH_TAG_LEN] } } - #[test] - fn decode_retry() { - fixture_init(); - let (packet, remainder) = - PublicPacket::decode(SAMPLE_RETRY, &FixedConnectionIdManager::new(5)).unwrap(); - assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID))); - assert!(packet.dcid().is_empty()); - assert_eq!(&packet.scid()[..], SERVER_CID); - assert_eq!(packet.token(), RETRY_TOKEN); - assert!(remainder.is_empty()); - } + impl CryptoCtx for TestFixture { + fn compute_mask(&self, sample: &[u8]) -> Res> { + Ok(vec![0xa5, 0xa5, 0xa5, 0xa5, 0xa5]) + } - #[test] - fn build_retry_multiple() { - // Run the build_retry test a few times. - // This increases the chance that the full comparison happens. - for _ in 0..32 { - build_retry(); + fn aead_decrypt(&self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res> { + let mut pt = body.to_vec(); + + for i in &mut pt { + *i ^= AEAD_MASK; + } + let pt_len = pt.len() - AUTH_TAG_LEN; + let at = TestFixture::auth_tag(hdr, &pt[0..pt_len]); + for i in 0..16 { + if at[i] != pt[pt_len + i] { + return Err(Error::DecryptError); + } + } + Ok(pt[0..pt_len].to_vec()) + } + + fn aead_encrypt(&self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res> { + let tag = TestFixture::auth_tag(hdr, body); + let mut enc = Encoder::with_capacity(body.len() + tag.len()); + enc.encode(body); + enc.encode(&tag); + for i in 0..enc.len() { + enc[i] ^= AEAD_MASK; + } + + Ok(enc.into()) } } - /// Check some packets that are clearly not valid Retry packets. - #[test] - fn invalid_retry() { - fixture_init(); - let cid_mgr = FixedConnectionIdManager::new(5); - let odcid = ConnectionId::from(CLIENT_CID); - - assert!(PublicPacket::decode(&[], &cid_mgr).is_err()); - - let (packet, remainder) = PublicPacket::decode(SAMPLE_RETRY, &cid_mgr).unwrap(); - assert!(remainder.is_empty()); - assert!(packet.is_valid_retry(&odcid)); - - let mut damaged_retry = SAMPLE_RETRY.to_vec(); - let last = damaged_retry.len() - 1; - damaged_retry[last] ^= 66; - let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); - assert!(remainder.is_empty()); - assert!(!packet.is_valid_retry(&odcid)); - - damaged_retry.truncate(last); - let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap(); - assert!(remainder.is_empty()); - assert!(!packet.is_valid_retry(&odcid)); - - // An invalid token should be rejected sooner. - damaged_retry.truncate(last - 4); - assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); - - damaged_retry.truncate(last - 1); - assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err()); - } - - const SAMPLE_VN: &[u8] = &[ - 0x80, 0x00, 0x00, 0x00, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x08, - 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, 0xff, 0x00, 0x00, 0x19, 0x0a, 0x0a, 0x0a, - 0x0a, - ]; - - #[test] - fn build_vn() { - fixture_init(); - let mut vn = PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID); - // Erase randomness from greasing... - assert_eq!(vn.len(), SAMPLE_VN.len()); - vn[0] &= 0x80; - for v in vn.iter_mut().skip(SAMPLE_VN.len() - 4) { - *v &= 0x0f; + impl ConnectionIdDecoder for TestFixture { + fn decode_cid(&self, dec: &mut Decoder) -> Option { + dec.decode(5).map(ConnectionId::from) } - assert_eq!(&vn, &SAMPLE_VN); + } + + fn default_hdr() -> PacketHdr { + PacketHdr { + tbyte: 0, + tipe: PacketType::Short, + version: Some(31), + dcid: ConnectionId(vec![1, 2, 3, 4, 5]), + scid: None, + pn: 0x0505, + epoch: 0, + hdr_len: 0, + body_len: 0, + } + } + + fn assert_headers_equal(left: &PacketHdr, right: &PacketHdr) { + assert_eq!(left.tipe, right.tipe); + assert_eq!(left.dcid, right.dcid); + assert_eq!(left.scid, right.scid); + assert_eq!(left.pn, right.pn); + } + + fn test_decrypt_packet(f: &TestFixture, packet: Vec) -> Res<(PacketHdr, Vec)> { + let mut phdr = decode_packet_hdr(f, &packet)?; + let body = decrypt_packet(f, PacketNumberDecoder::new(Some(0)), &mut phdr, &packet)?; + Ok((phdr, body)) + } + + fn test_encrypt_decrypt(f: &TestFixture, hdr: &mut PacketHdr, body: &[u8]) -> PacketHdr { + let packet = encode_packet(f, hdr, &TEST_BODY); + let res = test_decrypt_packet(&f, packet).unwrap(); + assert_headers_equal(&hdr, &res.0); + assert_eq!(body.to_vec(), res.1); + res.0 } #[test] - fn parse_vn() { - let (packet, remainder) = - PublicPacket::decode(SAMPLE_VN, &FixedConnectionIdManager::new(5)).unwrap(); - assert!(remainder.is_empty()); - assert_eq!(&packet.dcid[..], SERVER_CID); - assert!(packet.scid.is_some()); - assert_eq!(&packet.scid.unwrap()[..], CLIENT_CID); + fn test_short_packet() { + let f = TestFixture {}; + let mut hdr = default_hdr(); + test_encrypt_decrypt(&f, &mut hdr, &TEST_BODY); } #[test] - fn decode_pn() { - // When the expected value is low, the value doesn't go negative. - assert_eq!(PublicPacket::decode_pn(0, 0, 1), 0); - assert_eq!(PublicPacket::decode_pn(0, 0xff, 1), 0xff); - assert_eq!(PublicPacket::decode_pn(10, 0, 1), 0); - assert_eq!(PublicPacket::decode_pn(0x7f, 0, 1), 0); - assert_eq!(PublicPacket::decode_pn(0x80, 0, 1), 0x100); - assert_eq!(PublicPacket::decode_pn(0x80, 2, 1), 2); - assert_eq!(PublicPacket::decode_pn(0x80, 0xff, 1), 0xff); - assert_eq!(PublicPacket::decode_pn(0x7ff, 0xfe, 1), 0x7fe); + fn test_short_packet_damaged() { + let f = TestFixture {}; + let hdr = default_hdr(); + let mut packet = encode_packet(&f, &hdr, &TEST_BODY); + let plen = packet.len(); + packet[plen - 1] ^= 0x7; + assert!(test_decrypt_packet(&f, packet).is_err()); + } - // This is invalid by spec, as we are expected to check for overflow around 2^62-1, - // but we don't need to worry about overflow - // and hitting this is basically impossible in practice. - assert_eq!( - PublicPacket::decode_pn(0x3fff_ffff_ffff_ffff, 2, 4), - 0x4000_0000_0000_0002 - ); + #[test] + fn test_handshake_packet() { + let f = TestFixture {}; + let mut hdr = default_hdr(); + hdr.tipe = PacketType::Handshake; + hdr.scid = Some(ConnectionId(vec![9, 8, 7, 6, 5, 4, 3, 2])); + test_encrypt_decrypt(&f, &mut hdr, &TEST_BODY); + } + + #[test] + fn test_handshake_packet_damaged() { + let f = TestFixture {}; + let mut hdr = default_hdr(); + hdr.tipe = PacketType::Handshake; + hdr.scid = Some(ConnectionId(vec![9, 8, 7, 6, 5, 4, 3, 2])); + let mut packet = encode_packet(&f, &hdr, &TEST_BODY); + let plen = packet.len(); + packet[plen - 1] ^= 0x7; + assert!(test_decrypt_packet(&f, packet).is_err()); + } + + #[test] + fn test_initial_packet() { + let f = TestFixture {}; + let mut hdr = default_hdr(); + let tipe = PacketType::Initial(vec![0x0, 0x0, 0x0, 0x0]); + hdr.tipe = tipe; + hdr.scid = Some(ConnectionId(vec![9, 8, 7, 6, 5, 4, 3, 2])); + test_encrypt_decrypt(&f, &mut hdr, &TEST_BODY); + } + + #[test] + fn test_initial_packet_damaged() { + let f = TestFixture {}; + let mut hdr = default_hdr(); + hdr.tipe = PacketType::Initial(vec![0x0, 0x0, 0x0, 0x0]); + hdr.scid = Some(ConnectionId(vec![9, 8, 7, 6, 5, 4, 3, 2])); + let mut packet = encode_packet(&f, &hdr, &TEST_BODY); + let plen = packet.len(); + packet[plen - 1] ^= 0x7; + assert!(test_decrypt_packet(&f, packet).is_err()); + } + + #[test] + fn test_retry() { + let mut hdr = default_hdr(); + hdr.tipe = PacketType::Retry { + odcid: ConnectionId(vec![9, 8, 7, 6, 5, 4, 3, 2]), + token: vec![99, 88, 77, 66, 55, 44, 33], + }; + hdr.scid = Some(ConnectionId(vec![1, 2, 3, 4, 5])); + let packet = encode_retry(&hdr); + let f = TestFixture {}; + let decoded = decode_packet_hdr(&f, &packet).expect("should decode"); + assert_eq!(decoded.tipe, hdr.tipe); + assert_eq!(decoded.version, hdr.version); + assert_eq!(decoded.dcid, hdr.dcid); + assert_eq!(decoded.scid, hdr.scid); + } + + #[test] + fn generate_initial_cid() { + for i in 0..100 { + let cid = ConnectionId::generate_initial(); + if !matches!(cid.len(), 8..=20) { + panic!("connection ID {:?}", cid); + } + } } } diff --git a/third_party/rust/neqo-transport/src/recovery.rs b/third_party/rust/neqo-transport/src/recovery.rs index 7a932f7012a4..7e8bcf7c5aab 100644 --- a/third_party/rust/neqo-transport/src/recovery.rs +++ b/third_party/rust/neqo-transport/src/recovery.rs @@ -8,25 +8,33 @@ use std::cmp::{max, min}; use std::collections::BTreeMap; +use std::fmt::{self, Display}; use std::ops::{Index, IndexMut}; use std::time::{Duration, Instant}; use smallvec::SmallVec; -use neqo_common::{qdebug, qinfo, qtrace}; +use neqo_common::{const_max, const_min, qdebug, qinfo}; -use crate::cc::CongestionControl; use crate::crypto::CryptoRecoveryToken; use crate::flow_mgr::FlowControlRecoveryToken; use crate::send_stream::StreamRecoveryToken; -use crate::tracking::{AckToken, PNSpace, SentPacket}; -use crate::LOCAL_IDLE_TIMEOUT; +use crate::tracking::{AckToken, PNSpace}; const GRANULARITY: Duration = Duration::from_millis(20); // Defined in -recovery 6.2 as 500ms but using lower value until we have RTT // caching. See https://github.com/mozilla/neqo/issues/79 const INITIAL_RTT: Duration = Duration::from_millis(100); + const PACKET_THRESHOLD: u64 = 3; +pub const MAX_DATAGRAM_SIZE: usize = 1232; // For ipv6, smaller than ipv4 (1252) +pub const INITIAL_CWND_PKTS: usize = 10; +const INITIAL_WINDOW: usize = const_min( + INITIAL_CWND_PKTS * MAX_DATAGRAM_SIZE, + const_max(2 * MAX_DATAGRAM_SIZE, 14720), +); +pub const MIN_CONG_WINDOW: usize = MAX_DATAGRAM_SIZE * 2; +const PERSISTENT_CONG_THRESH: u32 = 3; #[derive(Debug, Clone)] pub enum RecoveryToken { @@ -34,7 +42,37 @@ pub enum RecoveryToken { Stream(StreamRecoveryToken), Crypto(CryptoRecoveryToken), Flow(FlowControlRecoveryToken), - HandshakeDone, +} + +#[derive(Debug, Clone)] +pub struct SentPacket { + ack_eliciting: bool, + time_sent: Instant, + pub tokens: Vec, + + time_declared_lost: Option, + + in_flight: bool, + size: usize, +} + +impl SentPacket { + pub fn new( + time_sent: Instant, + ack_eliciting: bool, + tokens: Vec, + size: usize, + in_flight: bool, + ) -> SentPacket { + SentPacket { + time_sent, + ack_eliciting, + tokens, + time_declared_lost: None, + size, + in_flight, + } + } } #[derive(Debug, Default)] @@ -80,14 +118,8 @@ impl RttVals { self.smoothed_rtt.unwrap_or(self.latest_rtt) } - fn pto(&self, pn_space: PNSpace) -> Duration { - self.rtt() - + max(4 * self.rttvar, GRANULARITY) - + if pn_space != PNSpace::ApplicationData { - Duration::from_millis(0) - } else { - self.max_ack_delay - } + fn pto(&self) -> Duration { + self.rtt() + max(4 * self.rttvar, GRANULARITY) + self.max_ack_delay } } @@ -98,8 +130,8 @@ pub(crate) struct LossRecoveryState { } impl LossRecoveryState { - fn new(mode: LossRecoveryMode, callback_time: Option) -> Self { - Self { + fn new(mode: LossRecoveryMode, callback_time: Option) -> LossRecoveryState { + LossRecoveryState { mode, callback_time, } @@ -112,28 +144,11 @@ impl LossRecoveryState { pub fn mode(&self) -> LossRecoveryMode { self.mode } - - pub fn get_pto_state(&mut self) -> Option<(PNSpace, bool)> { - if let LossRecoveryMode::PtoExpired { - dgram_available, - min_pn_space, - } = &mut self.mode - { - if *dgram_available > 0 { - *dgram_available -= 1; - Some((*min_pn_space, true)) - } else { - Some((*min_pn_space, false)) - } - } else { - None - } - } } impl Default for LossRecoveryState { - fn default() -> Self { - Self { + fn default() -> LossRecoveryState { + LossRecoveryState { mode: LossRecoveryMode::None, callback_time: None, } @@ -143,20 +158,15 @@ impl Default for LossRecoveryState { #[derive(Debug, PartialEq, Clone, Copy)] pub(crate) enum LossRecoveryMode { None, - LostPacketsTimer, // lost packet timer is armed. - PtoTimer, // pto timer is armed - PtoExpired { - dgram_available: usize, - min_pn_space: PNSpace, - }, // pto expired, in this state we should send pto packets. + LostPackets, + PTO, } #[derive(Debug, Default)] pub(crate) struct LossRecoverySpace { + tx_pn: u64, largest_acked: Option, largest_acked_sent_time: Option, - time_of_last_sent_ack_eliciting_packet: Option, - ack_eliciting_outstanding: u64, sent_packets: BTreeMap, } @@ -174,43 +184,6 @@ impl LossRecoverySpace { earliest } - pub fn get_largest_acked(&self) -> Option { - self.largest_acked - } - - pub fn ack_eliciting_outstanding(&self) -> bool { - self.ack_eliciting_outstanding > 0 - } - - pub fn time_of_last_sent_ack_eliciting_packet(&self) -> Option { - if self.ack_eliciting_outstanding() { - debug_assert!(self.time_of_last_sent_ack_eliciting_packet.is_some()); - self.time_of_last_sent_ack_eliciting_packet - } else { - None - } - } - - pub fn on_packet_sent(&mut self, packet_number: u64, sent_packet: SentPacket) { - if sent_packet.ack_eliciting { - self.time_of_last_sent_ack_eliciting_packet = Some(sent_packet.time_sent); - self.ack_eliciting_outstanding += 1; - } - self.sent_packets.insert(packet_number, sent_packet); - } - - pub fn remove_packet(&mut self, pn: u64) -> Option { - if let Some(sent) = self.sent_packets.remove(&pn) { - if sent.ack_eliciting { - debug_assert!(self.ack_eliciting_outstanding > 0); - self.ack_eliciting_outstanding -= 1; - } - Some(sent) - } else { - None - } - } - // Remove all the acked packets. Returns them in ascending order -- largest // (i.e. highest PN) acked packet is last. fn remove_acked(&mut self, acked_ranges: Vec<(u64, u64)>) -> (Vec, bool) { @@ -219,7 +192,7 @@ impl LossRecoverySpace { for (end, start) in acked_ranges { // ^^ Notabug: see Frame::decode_ack_frame() for pn in start..=end { - if let Some(sent) = self.remove_packet(pn) { + if let Some(sent) = self.sent_packets.remove(&pn) { qdebug!("acked={}", pn); eliciting |= sent.ack_eliciting; acked_packets.insert(pn, sent); @@ -233,10 +206,11 @@ impl LossRecoverySpace { } /// Remove all tracked packets from the space. - /// This is called by a client when 0-RTT packets are dropped, when a Retry is received - /// and when keys are dropped. + /// This is called by a client when 0-RTT packets are dropped and when a Retry is received. fn remove_ignored(&mut self) -> impl Iterator { - self.ack_eliciting_outstanding = 0; + // The largest acknowledged or loss_time should still be unset. + // The client should not have received any ACK frames when it drops 0-RTT. + assert!(self.largest_acked.is_none()); std::mem::replace(&mut self.sent_packets, BTreeMap::default()) .into_iter() .map(|(_, v)| v) @@ -269,20 +243,182 @@ impl LossRecoverySpaces { } } +#[derive(Debug)] +struct CongestionControl { + congestion_window: usize, // = kInitialWindow + bytes_in_flight: usize, + congestion_recovery_start_time: Option, + ssthresh: usize, +} + +impl Default for CongestionControl { + fn default() -> Self { + CongestionControl { + congestion_window: INITIAL_WINDOW, + bytes_in_flight: 0, + congestion_recovery_start_time: None, + ssthresh: std::usize::MAX, + } + } +} + +impl Display for CongestionControl { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "CongCtrl {}/{} ssthresh {}", + self.bytes_in_flight, self.congestion_window, self.ssthresh + ) + } +} + +impl CongestionControl { + #[cfg(test)] + pub fn cwnd(&self) -> usize { + self.congestion_window + } + + #[cfg(test)] + pub fn ssthresh(&self) -> usize { + self.ssthresh + } + + fn cwnd_avail(&self) -> usize { + // BIF can be higher than cwnd due to PTO packets, which are sent even + // if avail is 0, but still count towards BIF. + self.congestion_window.saturating_sub(self.bytes_in_flight) + } + + // Multi-packet version of OnPacketAckedCC + fn on_packets_acked(&mut self, acked_pkts: &[SentPacket]) { + for pkt in acked_pkts + .iter() + .filter(|pkt| pkt.in_flight) + .filter(|pkt| pkt.time_declared_lost.is_none()) + { + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + + if self.in_congestion_recovery(pkt.time_sent) { + // Do not increase congestion window in recovery period. + continue; + } + if self.app_limited() { + // Do not increase congestion_window if application limited. + continue; + } + + if self.congestion_window < self.ssthresh { + self.congestion_window += pkt.size; + qinfo!([self], "slow start"); + } else { + self.congestion_window += (MAX_DATAGRAM_SIZE * pkt.size) / self.congestion_window; + qinfo!([self], "congestion avoidance"); + } + } + } + + fn on_packets_lost( + &mut self, + now: Instant, + largest_acked_sent: Option, + pto: Duration, + lost_packets: &[SentPacket], + ) { + if lost_packets.is_empty() { + return; + } + + for pkt in lost_packets.iter().filter(|pkt| pkt.in_flight) { + assert!(self.bytes_in_flight >= pkt.size); + self.bytes_in_flight -= pkt.size; + } + + qdebug!([self], "Pkts lost {}", lost_packets.len()); + + let last_lost_pkt = lost_packets.last().unwrap(); + self.on_congestion_event(now, last_lost_pkt.time_sent); + + let in_persistent_congestion = { + let congestion_period = pto * PERSISTENT_CONG_THRESH; + + match largest_acked_sent { + Some(las) => las < last_lost_pkt.time_sent - congestion_period, + None => { + // Nothing has ever been acked. Could still be PC. + let first_lost_pkt_sent = lost_packets.first().unwrap().time_sent; + last_lost_pkt.time_sent - first_lost_pkt_sent > congestion_period + } + } + }; + if in_persistent_congestion { + qinfo!([self], "persistent congestion"); + self.congestion_window = MIN_CONG_WINDOW; + } + } + + fn on_packet_sent(&mut self, pkt: &SentPacket) { + if !pkt.in_flight { + return; + } + + self.bytes_in_flight += pkt.size; + qdebug!( + [self], + "Pkt Sent len {}, bif {}, cwnd {}", + pkt.size, + self.bytes_in_flight, + self.congestion_window + ); + debug_assert!(self.bytes_in_flight <= self.congestion_window); + } + + fn in_congestion_recovery(&self, sent_time: Instant) -> bool { + self.congestion_recovery_start_time + .map(|start| sent_time <= start) + .unwrap_or(false) + } + + fn on_congestion_event(&mut self, now: Instant, sent_time: Instant) { + // Start a new congestion event if packet was sent after the + // start of the previous congestion recovery period. + if !self.in_congestion_recovery(sent_time) { + self.congestion_recovery_start_time = Some(now); + self.congestion_window /= 2; // kLossReductionFactor = 0.5 + self.congestion_window = max(self.congestion_window, MIN_CONG_WINDOW); + self.ssthresh = self.congestion_window; + qinfo!( + [self], + "Cong event -> recovery; cwnd {}, ssthresh {}", + self.congestion_window, + self.ssthresh + ); + } else { + qdebug!([self], "Cong event but already in recovery"); + } + } + + fn app_limited(&self) -> bool { + //TODO(agrover): how do we get this info?? + false + } +} + #[derive(Debug, Default)] pub(crate) struct LossRecovery { pto_count: u32, + time_of_last_sent_ack_eliciting_packet: Option, rtt_vals: RttVals, + cc: CongestionControl, enable_timed_loss_detection: bool, spaces: LossRecoverySpaces, - loss_recovery_state: LossRecoveryState, } impl LossRecovery { - pub fn new() -> Self { - Self { + pub fn new() -> LossRecovery { + LossRecovery { rtt_vals: RttVals { min_rtt: Duration::from_secs(u64::max_value()), max_ack_delay: Duration::from_millis(25), @@ -290,7 +426,7 @@ impl LossRecovery { ..RttVals::default() }, - ..Self::default() + ..LossRecovery::default() } } @@ -308,24 +444,28 @@ impl LossRecovery { self.cc.cwnd_avail() } + pub fn next_pn(&mut self, pn_space: PNSpace) -> u64 { + self.spaces[pn_space].tx_pn + } + + pub fn inc_pn(&mut self, pn_space: PNSpace) { + self.spaces[pn_space].tx_pn += 1; + } + + pub fn increment_pto_count(&mut self) { + self.pto_count += 1; + } + pub fn largest_acknowledged_pn(&self, pn_space: PNSpace) -> Option { self.spaces[pn_space].largest_acked } pub fn pto(&self) -> Duration { - self.rtt_vals.pto(PNSpace::ApplicationData) + self.rtt_vals.pto() } - pub fn drop_0rtt(&mut self) -> Vec { - // The largest acknowledged or loss_time should still be unset. - // The client should not have received any ACK frames when it drops 0-RTT. - assert!(self.spaces[PNSpace::ApplicationData] - .get_largest_acked() - .is_none()); - self.spaces[PNSpace::ApplicationData] - .remove_ignored() - .inspect(|p| self.cc.discard(&p)) - .collect() + pub fn drop_0rtt(&mut self) -> impl Iterator { + self.spaces[PNSpace::ApplicationData].remove_ignored() } pub fn on_packet_sent( @@ -335,8 +475,14 @@ impl LossRecovery { sent_packet: SentPacket, ) { qdebug!([self], "packet {:?}-{} sent.", pn_space, packet_number); + if sent_packet.ack_eliciting { + self.time_of_last_sent_ack_eliciting_packet = Some(sent_packet.time_sent); + } self.cc.on_packet_sent(&sent_packet); - self.spaces[pn_space].on_packet_sent(packet_number, sent_packet); + + self.spaces[pn_space] + .sent_packets + .insert(packet_number, sent_packet); } /// Returns (acked packets, lost packets) @@ -388,7 +534,7 @@ impl LossRecovery { self.cc.on_packets_lost( now, prev_largest_acked_sent_time, - self.rtt_vals.pto(pn_space), + self.rtt_vals.pto(), &lost_packets, ); @@ -406,32 +552,15 @@ impl LossRecovery { max(rtt * 9 / 8, GRANULARITY) } - // Calculate PTO duration - fn pto_timeout(&self, pn_space: PNSpace) -> Duration { - self.rtt_vals - .pto(pn_space) - .checked_mul(1 << self.pto_count) - .unwrap_or(LOCAL_IDLE_TIMEOUT * 2) - } - /// When receiving a retry, get all the sent packets so that they can be flushed. /// We also need to pretend that they never happened for the purposes of congestion control. pub fn retry(&mut self) -> Vec { - let cc = &mut self.cc; self.spaces .iter_mut() .flat_map(|spc| spc.remove_ignored()) - .inspect(|p| cc.discard(&p)) .collect() } - /// Discard state for a given packet number space. - pub fn discard(&mut self, pn_space: PNSpace) { - for p in self.spaces[pn_space].remove_ignored() { - self.cc.discard(&p); - } - } - /// Detect packets whose contents may need to be retransmitted. pub fn detect_lost_packets(&mut self, pn_space: PNSpace, now: Instant) -> Vec { self.enable_timed_loss_detection = false; @@ -503,7 +632,8 @@ impl LossRecovery { for pn in really_lost_pns { packet_space - .remove_packet(pn) + .sent_packets + .remove(&pn) .expect("PN must be in sent_packets"); } @@ -520,55 +650,52 @@ impl LossRecovery { lost_packets } - pub fn callback_time(&mut self) -> Option { - self.loss_recovery_state.callback_time() - } + pub fn get_timer(&mut self) -> LossRecoveryState { + qdebug!([self], "get_loss_detection_timer."); - #[cfg(test)] - pub fn state_mode(&self) -> LossRecoveryMode { - self.loss_recovery_state.mode() - } - - pub fn calculate_timer(&mut self) -> Option { - qtrace!([self], "get_loss_detection_timer."); - - let has_ack_eliciting_out = self.spaces.iter().any(|sp| sp.ack_eliciting_outstanding()); + let has_ack_eliciting_out = self + .spaces + .iter() + .flat_map(|spc| spc.sent_packets.values()) + .any(|sp| sp.ack_eliciting); qdebug!([self], "has_ack_eliciting_out={}", has_ack_eliciting_out,); if !has_ack_eliciting_out { - self.loss_recovery_state = LossRecoveryState::new(LossRecoveryMode::None, None); - return None; + return LossRecoveryState::new(LossRecoveryMode::None, None); } qinfo!( [self], - "sent packets init:({} ack_eliciting:{}), hs:({} ack_eliciting:{}), app:({} ack_eliciting:{})", + "sent packets {} {} {}", self.spaces[PNSpace::Initial].sent_packets.len(), - self.spaces[PNSpace::Initial].ack_eliciting_outstanding(), self.spaces[PNSpace::Handshake].sent_packets.len(), - self.spaces[PNSpace::Handshake].ack_eliciting_outstanding(), - self.spaces[PNSpace::ApplicationData].sent_packets.len(), - self.spaces[PNSpace::ApplicationData].ack_eliciting_outstanding() + self.spaces[PNSpace::ApplicationData].sent_packets.len() ); // QUIC only has one timer, but it does double duty because it falls // back to other uses if first use is not needed: first the loss // detection timer, and then the probe timeout (PTO). - self.loss_recovery_state = if let Some((_, earliest_time)) = self.get_earliest_loss_time() { - LossRecoveryState::new(LossRecoveryMode::LostPacketsTimer, Some(earliest_time)) + let (mode, maybe_timer) = if let Some((_, earliest_time)) = self.get_earliest_loss_time() { + (LossRecoveryMode::LostPackets, Some(earliest_time)) } else { - LossRecoveryState::new(LossRecoveryMode::PtoTimer, self.get_min_pto()) + // Calculate PTO duration + let timeout = self.rtt_vals.pto() * 2_u32.pow(self.pto_count); + ( + LossRecoveryMode::PTO, + self.time_of_last_sent_ack_eliciting_packet + .map(|i| i + timeout), + ) }; qdebug!( [self], "loss_detection_timer mode={:?} timer={:?}", - self.loss_recovery_state.mode(), - self.loss_recovery_state.callback_time() + mode, + maybe_timer ); - self.loss_recovery_state.callback_time() + LossRecoveryState::new(mode, maybe_timer) } /// Find when the earliest sent packet should be considered lost. @@ -590,87 +717,6 @@ impl LossRecovery { .min_by_key(|&(_, time)| time) .map(|(spc, val)| (spc, val + self.loss_delay())) } - - fn pto_time_for_pn(&self, pn_space: PNSpace) -> Option { - if let Some(time) = self.spaces[pn_space].time_of_last_sent_ack_eliciting_packet() { - Some(time + self.pto_timeout(pn_space)) - } else { - None - } - } - - /// Find when the last ack eliciting packet was sent. - pub fn get_min_pto(&self) -> Option { - // TODO ignore PNSpace::Application until handshake is done -> a server side problem. - PNSpace::iter() - .filter_map(|spc| self.pto_time_for_pn(*spc)) - .min_by_key(|&time| time) - } - - pub fn get_min_pto_pn_space(&self, now: Instant) -> Option { - PNSpace::iter() - .filter_map(|spc| { - if let Some(time) = self.pto_time_for_pn(*spc) { - if time <= now { - Some(*spc) - } else { - None - } - } else { - None - } - }) - .min_by_key(|&spc| spc) - } - - pub fn check_loss_detection_timeout(&mut self, now: Instant) -> Option> { - qdebug!([self], "check_loss_timeouts"); - - if self.loss_recovery_state.mode() == LossRecoveryMode::None { - // LR not the active timer - return None; - } - - if self.callback_time() > Some(now) { - // LR timer, but hasn't expired. - return None; - } - - // Timer expired and LR was active timer. - match self.loss_recovery_state.mode() { - LossRecoveryMode::None => unreachable!(), - LossRecoveryMode::LostPacketsTimer => { - // Time threshold loss detection - let (pn_space, _) = self - .get_earliest_loss_time() - .expect("must be sent packets if in LostPackets mode"); - return Some(self.detect_lost_packets(pn_space, now)); - } - LossRecoveryMode::PtoTimer => { - qinfo!( - [self], - "check_loss_detection_timeout -send_one_or_two_packets" - ); - - if let Some(min_pn_space) = self.get_min_pto_pn_space(now) { - self.loss_recovery_state = LossRecoveryState::new( - LossRecoveryMode::PtoExpired { - dgram_available: 1, - min_pn_space, - }, - Some(now), - ); - self.pto_count += 1; - } - } - _ => {} // We are already in PtoExpired state - } - None - } - - pub fn get_pto_state(&mut self) -> Option<(PNSpace, bool)> { - self.loss_recovery_state.get_pto_state() - } } impl ::std::fmt::Display for LossRecovery { @@ -927,11 +973,11 @@ mod tests { assert_sent_times(&lr, None, None, Some(pn1_sent_time)); // After time elapses, pn 1 is marked lost. - let callback_time = lr.calculate_timer(); + let lr_state = lr.get_timer(); let pn1_lost_time = pn1_sent_time + (INITIAL_RTT * 9 / 8); - assert_eq!(callback_time, Some(pn1_lost_time)); - match lr.state_mode() { - LossRecoveryMode::LostPacketsTimer => { + assert_eq!(lr_state.callback_time, Some(pn1_lost_time)); + match lr_state.mode { + LossRecoveryMode::LostPackets => { let packets = lr.detect_lost_packets(PNSpace::ApplicationData, pn1_lost_time); assert_eq!(packets.len(), 1) diff --git a/third_party/rust/neqo-transport/src/recv_stream.rs b/third_party/rust/neqo-transport/src/recv_stream.rs index 18f4b1a5a422..4b4703040675 100644 --- a/third_party/rust/neqo-transport/src/recv_stream.rs +++ b/third_party/rust/neqo-transport/src/recv_stream.rs @@ -29,7 +29,7 @@ pub(crate) type RecvStreams = BTreeMap; /// Holds data not yet read by application. Orders and dedupes data ranges /// from incoming STREAM frames. -#[derive(Debug, Default)] +#[derive(Debug, Default, PartialEq)] pub struct RxStreamOrderer { data_ranges: BTreeMap>, // (start_offset, data) retired: u64, // Number of bytes the application has read @@ -274,9 +274,7 @@ impl RxStreamOrderer { } /// QUIC receiving states, based on -transport 3.2. -#[derive(Debug)] -#[allow(dead_code)] -// Because a dead_code warning is easier than clippy::unused_self, see https://github.com/rust-lang/rust/issues/68408 +#[derive(Debug, PartialEq)] enum RecvStreamState { Recv { recv_buf: RxStreamOrderer, @@ -297,7 +295,7 @@ enum RecvStreamState { impl RecvStreamState { fn new(max_bytes: u64) -> Self { - Self::Recv { + RecvStreamState::Recv { recv_buf: RxStreamOrderer::new(), max_bytes, max_stream_data: max_bytes, @@ -306,29 +304,34 @@ impl RecvStreamState { fn name(&self) -> &str { match self { - Self::Recv { .. } => "Recv", - Self::SizeKnown { .. } => "SizeKnown", - Self::DataRecvd { .. } => "DataRecvd", - Self::DataRead => "DataRead", - Self::ResetRecvd => "ResetRecvd", + RecvStreamState::Recv { .. } => "Recv", + RecvStreamState::SizeKnown { .. } => "SizeKnown", + RecvStreamState::DataRecvd { .. } => "DataRecvd", + RecvStreamState::DataRead => "DataRead", + RecvStreamState::ResetRecvd => "ResetRecvd", } } fn recv_buf(&self) -> Option<&RxStreamOrderer> { match self { - Self::Recv { recv_buf, .. } - | Self::SizeKnown { recv_buf, .. } - | Self::DataRecvd { recv_buf } => Some(recv_buf), - Self::DataRead | Self::ResetRecvd => None, + RecvStreamState::Recv { recv_buf, .. } + | RecvStreamState::SizeKnown { recv_buf, .. } + | RecvStreamState::DataRecvd { recv_buf } => Some(recv_buf), + RecvStreamState::DataRead | RecvStreamState::ResetRecvd => None, } } fn final_size(&self) -> Option { match self { - Self::SizeKnown { final_size, .. } => Some(*final_size), + RecvStreamState::SizeKnown { final_size, .. } => Some(*final_size), _ => None, } } + + fn transition(&mut self, new_state: Self) { + qtrace!("RecvStream state {} -> {}", self.name(), new_state.name()); + *self = new_state; + } } /// Implement a QUIC receive stream. @@ -355,27 +358,6 @@ impl RecvStream { } } - fn set_state(&mut self, new_state: RecvStreamState) { - debug_assert_ne!( - mem::discriminant(&self.state), - mem::discriminant(&new_state) - ); - qtrace!( - "RecvStream {} state {} -> {}", - self.stream_id.as_u64(), - self.state.name(), - new_state.name() - ); - - if let RecvStreamState::Recv { .. } = &self.state { - self.flow_mgr - .borrow_mut() - .clear_max_stream_data(self.stream_id) - } - - self.state = new_state; - } - pub fn inbound_stream_frame(&mut self, fin: bool, offset: u64, data: Vec) -> Res<()> { let new_end = offset + data.len() as u64; @@ -406,9 +388,10 @@ impl RecvStream { let buf = mem::replace(recv_buf, RxStreamOrderer::new()); if final_size == buf.retired() + buf.bytes_ready() as u64 { - self.set_state(RecvStreamState::DataRecvd { recv_buf: buf }); + self.state + .transition(RecvStreamState::DataRecvd { recv_buf: buf }); } else { - self.set_state(RecvStreamState::SizeKnown { + self.state.transition(RecvStreamState::SizeKnown { recv_buf: buf, final_size, }); @@ -424,7 +407,8 @@ impl RecvStream { recv_buf.inbound_frame(offset, data)?; if *final_size == recv_buf.retired() + recv_buf.bytes_ready() as u64 { let buf = mem::replace(recv_buf, RxStreamOrderer::new()); - self.set_state(RecvStreamState::DataRecvd { recv_buf: buf }); + self.state + .transition(RecvStreamState::DataRecvd { recv_buf: buf }); } } RecvStreamState::DataRecvd { .. } @@ -446,7 +430,7 @@ impl RecvStream { RecvStreamState::Recv { .. } | RecvStreamState::SizeKnown { .. } => { self.conn_events .recv_stream_reset(self.stream_id, application_error_code); - self.set_state(RecvStreamState::ResetRecvd); + self.state.transition(RecvStreamState::ResetRecvd); } _ => { // Ignore reset if in DataRecvd, DataRead, or ResetRecvd @@ -505,7 +489,7 @@ impl RecvStream { let bytes_read = recv_buf.read(buf)?; let fin_read = recv_buf.buffered() == 0; if fin_read { - self.set_state(RecvStreamState::DataRead) + self.state.transition(RecvStreamState::DataRead) } Ok((bytes_read, fin_read)) } @@ -519,10 +503,10 @@ impl RecvStream { qtrace!("stop_sending called when in state {}", self.state.name()); match &self.state { RecvStreamState::Recv { .. } | RecvStreamState::SizeKnown { .. } => { - self.set_state(RecvStreamState::ResetRecvd); + self.state.transition(RecvStreamState::ResetRecvd); self.flow_mgr.borrow_mut().stop_sending(self.stream_id, err) } - RecvStreamState::DataRecvd { .. } => self.set_state(RecvStreamState::DataRead), + RecvStreamState::DataRecvd { .. } => self.state.transition(RecvStreamState::DataRead), RecvStreamState::DataRead | RecvStreamState::ResetRecvd => { // Already in terminal state } @@ -533,8 +517,6 @@ impl RecvStream { #[cfg(test)] mod tests { use super::*; - use crate::frame::Frame; - use neqo_common::matches; #[test] fn test_stream_rx() { @@ -766,26 +748,4 @@ mod tests { assert_eq!(rx_ord.buffered(), 15); assert_eq!(rx_ord.retired(), 2); } - - #[test] - fn no_stream_flowc_event_after_exiting_recv() { - let flow_mgr = Rc::new(RefCell::new(FlowMgr::default())); - let conn_events = ConnectionEvents::default(); - - let frame1 = vec![0; RX_STREAM_DATA_WINDOW as usize]; - - let mut s = RecvStream::new( - 67.into(), - RX_STREAM_DATA_WINDOW, - Rc::clone(&flow_mgr), - conn_events, - ); - - s.inbound_stream_frame(false, 0, frame1).unwrap(); - flow_mgr.borrow_mut().max_stream_data(67.into(), 100); - assert!(matches!(s.flow_mgr.borrow().peek().unwrap(), Frame::MaxStreamData{..})); - s.inbound_stream_frame(true, RX_STREAM_DATA_WINDOW, vec![]) - .unwrap(); - assert!(matches!(s.flow_mgr.borrow().peek(), None)); - } } diff --git a/third_party/rust/neqo-transport/src/send_stream.rs b/third_party/rust/neqo-transport/src/send_stream.rs index 0d48118e8feb..903a5c9947bb 100644 --- a/third_party/rust/neqo-transport/src/send_stream.rs +++ b/third_party/rust/neqo-transport/src/send_stream.rs @@ -22,7 +22,6 @@ use crate::flow_mgr::FlowMgr; use crate::frame::{Frame, TxMode}; use crate::recovery::RecoveryToken; use crate::stream_id::StreamId; -use crate::tracking::PNSpace; use crate::{AppError, Error, Res}; #[derive(Debug, PartialEq, Clone, Copy)] @@ -284,17 +283,17 @@ impl TxBuffer { pub fn new() -> Self { Self { - send_buf: VecDeque::with_capacity(Self::BUFFER_SIZE), + send_buf: VecDeque::with_capacity(TxBuffer::BUFFER_SIZE), ..Self::default() } } /// Attempt to add some or all of the passed-in buffer to the TxBuffer. pub fn send(&mut self, buf: &[u8]) -> usize { - let can_buffer = min(Self::BUFFER_SIZE - self.buffered(), buf.len()); + let can_buffer = min(TxBuffer::BUFFER_SIZE - self.buffered(), buf.len()); if can_buffer > 0 { self.send_buf.extend(&buf[..can_buffer]); - assert!(self.send_buf.len() <= Self::BUFFER_SIZE); + assert!(self.send_buf.len() <= TxBuffer::BUFFER_SIZE); } can_buffer } @@ -363,7 +362,7 @@ impl TxBuffer { } fn avail(&self) -> usize { - Self::BUFFER_SIZE - self.buffered() + TxBuffer::BUFFER_SIZE - self.buffered() } pub fn highest_sent(&self) -> u64 { @@ -393,44 +392,60 @@ enum SendStreamState { impl SendStreamState { fn tx_buf(&self) -> Option<&TxBuffer> { match self { - Self::Send { send_buf } | Self::DataSent { send_buf, .. } => Some(send_buf), - Self::Ready | Self::DataRecvd { .. } | Self::ResetSent | Self::ResetRecvd => None, + SendStreamState::Send { send_buf } | SendStreamState::DataSent { send_buf, .. } => { + Some(send_buf) + } + SendStreamState::Ready + | SendStreamState::DataRecvd { .. } + | SendStreamState::ResetSent + | SendStreamState::ResetRecvd => None, } } fn tx_buf_mut(&mut self) -> Option<&mut TxBuffer> { match self { - Self::Send { send_buf } | Self::DataSent { send_buf, .. } => Some(send_buf), - Self::Ready | Self::DataRecvd { .. } | Self::ResetSent | Self::ResetRecvd => None, + SendStreamState::Send { send_buf } | SendStreamState::DataSent { send_buf, .. } => { + Some(send_buf) + } + SendStreamState::Ready + | SendStreamState::DataRecvd { .. } + | SendStreamState::ResetSent + | SendStreamState::ResetRecvd => None, } } fn tx_avail(&self) -> u64 { match self { // In Ready, TxBuffer not yet allocated but size is known - Self::Ready => TxBuffer::BUFFER_SIZE.try_into().unwrap(), - Self::Send { send_buf } | Self::DataSent { send_buf, .. } => { + SendStreamState::Ready => TxBuffer::BUFFER_SIZE.try_into().unwrap(), + SendStreamState::Send { send_buf } | SendStreamState::DataSent { send_buf, .. } => { send_buf.avail().try_into().unwrap() } - Self::DataRecvd { .. } | Self::ResetSent | Self::ResetRecvd => 0, + SendStreamState::DataRecvd { .. } + | SendStreamState::ResetSent + | SendStreamState::ResetRecvd => 0, } } fn final_size(&self) -> Option { match self { - Self::DataSent { final_size, .. } | Self::DataRecvd { final_size } => Some(*final_size), - Self::Ready | Self::Send { .. } | Self::ResetSent | Self::ResetRecvd => None, + SendStreamState::DataSent { final_size, .. } + | SendStreamState::DataRecvd { final_size } => Some(*final_size), + SendStreamState::Ready + | SendStreamState::Send { .. } + | SendStreamState::ResetSent + | SendStreamState::ResetRecvd => None, } } fn name(&self) -> &str { match self { - Self::Ready => "Ready", - Self::Send { .. } => "Send", - Self::DataSent { .. } => "DataSent", - Self::DataRecvd { .. } => "DataRecvd", - Self::ResetSent => "ResetSent", - Self::ResetRecvd => "ResetRecvd", + SendStreamState::Ready => "Ready", + SendStreamState::Send { .. } => "Send", + SendStreamState::DataSent { .. } => "DataSent", + SendStreamState::DataRecvd { .. } => "DataRecvd", + SendStreamState::ResetSent => "ResetSent", + SendStreamState::ResetRecvd => "ResetRecvd", } } @@ -744,11 +759,11 @@ impl SendStreams { pub(crate) fn get_frame( &mut self, - space: PNSpace, + epoch: u16, mode: TxMode, remaining: usize, ) -> Option<(Frame, Option)> { - if space != PNSpace::ApplicationData { + if epoch != 3 && epoch != 1 { return None; } @@ -759,11 +774,11 @@ impl SendStreams { Frame::new_stream(stream_id.as_u64(), offset, data, complete, remaining) { qdebug!( - "Stream {} sending bytes {}-{}, space {:?}, mode {:?}", + "Stream {} sending bytes {}-{}, epoch {}, mode {:?}", stream_id.as_u64(), offset, offset + length as u64, - space, + epoch, mode, ); let fin = complete && length == data.len(); diff --git a/third_party/rust/neqo-transport/src/server.rs b/third_party/rust/neqo-transport/src/server.rs index 1279901a9928..0659b0732e46 100644 --- a/third_party/rust/neqo-transport/src/server.rs +++ b/third_party/rust/neqo-transport/src/server.rs @@ -15,10 +15,12 @@ use neqo_crypto::{ AntiReplay, }; -use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdManager, ConnectionIdRef}; -use crate::connection::{Connection, Output, State}; -use crate::packet::{PacketBuilder, PacketType, PublicPacket}; -use crate::Res; +use crate::connection::{Connection, ConnectionIdManager, Output, State}; +use crate::packet::{ + decode_packet_hdr, encode_packet_vn, encode_retry, ConnectionId, ConnectionIdDecoder, + PacketHdr, PacketType, Version, +}; +use crate::{Res, QUIC_VERSION}; use std::cell::RefCell; use std::collections::{HashMap, HashSet, VecDeque}; @@ -82,7 +84,7 @@ struct RetryToken { impl RetryToken { fn new(now: Instant) -> Res { - Ok(Self { + Ok(RetryToken { require_retry: false, self_encrypt: SelfEncrypt::new(TLS_VERSION_1_3, TLS_AES_128_GCM_SHA256)?, start_time: now, @@ -122,7 +124,7 @@ impl RetryToken { let end_millis = u32::try_from(end.duration_since(self.start_time).as_millis())?; token.encode_uint(4, end_millis); token.encode(dcid); - let peer_addr = Self::encode_peer_address(peer_address); + let peer_addr = RetryToken::encode_peer_address(peer_address); Ok(self.self_encrypt.seal(&peer_addr, &token)?) } @@ -138,7 +140,7 @@ impl RetryToken { peer_address: SocketAddr, now: Instant, ) -> Option { - let peer_addr = Self::encode_peer_address(peer_address); + let peer_addr = RetryToken::encode_peer_address(peer_address); let data = if let Ok(d) = self.self_encrypt.open(&peer_addr, token) { d } else { @@ -159,18 +161,22 @@ impl RetryToken { pub fn validate( &self, - token: &[u8], + hdr: &PacketHdr, peer_address: SocketAddr, now: Instant, ) -> RetryTokenResult { - if token.is_empty() { - if self.require_retry { - RetryTokenResult::Validate + if let PacketType::Initial(token) = &hdr.tipe { + if token.is_empty() { + if self.require_retry { + RetryTokenResult::Validate + } else { + RetryTokenResult::Pass + } + } else if let Some(cid) = self.decrypt_token(token, peer_address, now) { + RetryTokenResult::Valid(cid) } else { - RetryTokenResult::Pass + RetryTokenResult::Invalid } - } else if let Some(cid) = self.decrypt_token(token, peer_address, now) { - RetryTokenResult::Valid(cid) } else { RetryTokenResult::Invalid } @@ -178,6 +184,8 @@ impl RetryToken { } pub struct Server { + /// The version this server supports (currently just one). + version: Version, /// The names of certificates. certs: Vec, /// The ALPN values that the server supports. @@ -214,6 +222,7 @@ impl Server { cid_manager: CidMgr, ) -> Res { Ok(Self { + version: QUIC_VERSION, certs: certs.iter().map(|x| String::from(x.as_ref())).collect(), protocols: protocols.iter().map(|x| String::from(x.as_ref())).collect(), anti_replay, @@ -226,6 +235,20 @@ impl Server { }) } + fn create_vn(&self, hdr: &PacketHdr, received: Datagram) -> Datagram { + let vn = encode_packet_vn(&PacketHdr::new( + 0, + // Actual version we support and a greased value. + PacketType::VN(vec![self.version, 0xaaba_cada]), + Some(0), + hdr.scid.as_ref().unwrap().clone(), + Some(hdr.dcid.clone()), + 0, // unused + 0, // unused + )); + Datagram::new(received.destination(), received.source(), vn) + } + pub fn set_retry_required(&mut self, require_retry: bool) { self.retry.set_retry_required(require_retry); } @@ -273,8 +296,8 @@ impl Server { out.dgram() } - fn connection(&self, cid: &ConnectionIdRef) -> Option { - if let Some(c) = self.connections.borrow().get(&cid[..]) { + fn connection(&self, cid: &ConnectionId) -> Option { + if let Some(c) = self.connections.borrow().get(cid) { Some(c.clone()) } else { None @@ -283,35 +306,38 @@ impl Server { fn handle_initial( &mut self, - dcid: ConnectionId, - scid: ConnectionId, - token: Vec, + hdr: PacketHdr, dgram: Datagram, now: Instant, ) -> Option { - match self.retry.validate(&token, dgram.source(), now) { + match self.retry.validate(&hdr, dgram.source(), now) { RetryTokenResult::Invalid => None, RetryTokenResult::Pass => self.accept_connection(None, dgram, now), RetryTokenResult::Valid(dcid) => self.accept_connection(Some(dcid), dgram, now), RetryTokenResult::Validate => { - qinfo!([self], "Send retry for {:?}", dcid); + qinfo!([self], "Send retry for {:?}", hdr.dcid); - let res = self.retry.generate_token(&dcid, dgram.source(), now); + let res = self.retry.generate_token(&hdr.dcid, dgram.source(), now); let token = if let Ok(t) = res { t } else { qerror!([self], "unable to generate token, dropping packet"); return None; }; - let new_dcid = self.cid_manager.borrow_mut().generate_cid(); - let packet = PacketBuilder::retry(&scid, &new_dcid, &token, &dcid); - if let Ok(p) = packet { - let retry = Datagram::new(dgram.destination(), dgram.source(), p); - Some(retry) - } else { - qerror!([self], "unable to encode retry, dropping packet"); - None - } + let payload = encode_retry(&PacketHdr::new( + 0, // tbyte (unused on encode) + PacketType::Retry { + odcid: hdr.dcid.clone(), + token, + }, + Some(self.version), + hdr.scid.as_ref().unwrap().clone(), + Some(self.cid_manager.borrow_mut().generate_cid()), + 0, // Packet number + 0, // Epoch + )); + let retry = Datagram::new(dgram.destination(), dgram.source(), payload); + Some(retry) } } } @@ -354,9 +380,9 @@ impl Server { // This is only looking at the first packet header in the datagram. // All packets in the datagram are routed to the same connection. - let res = PublicPacket::decode(&dgram[..], self.cid_manager.borrow().as_decoder()); - let (packet, _remainder) = match res { - Ok(res) => res, + let res = decode_packet_hdr(self.cid_manager.borrow().as_decoder(), &dgram[..]); + let hdr = match res { + Ok(h) => h, _ => { qtrace!([self], "Discarding {:?}", dgram); return None; @@ -364,11 +390,11 @@ impl Server { }; // Finding an existing connection. Should be the most common case. - if let Some(c) = self.connection(packet.dcid()) { + if let Some(c) = self.connection(&hdr.dcid) { return self.process_connection(c, Some(dgram), now); } - if packet.packet_type() == PacketType::Short { + if hdr.tipe == PacketType::Short { // TODO send a stateless reset here. qtrace!([self], "Short header packet for an unknown connection"); return None; @@ -378,16 +404,12 @@ impl Server { qtrace!([self], "Bogus packet: too short"); return None; } - if packet.packet_type() == PacketType::OtherVersion { - let vn = PacketBuilder::version_negotiation(packet.scid(), packet.dcid()); - return Some(Datagram::new(dgram.destination(), dgram.source(), vn)); + + if hdr.version != Some(self.version) { + return Some(self.create_vn(&hdr, dgram)); } - // Copy values from `packet` because they are currently still borrowing from `dgram`. - let dcid = ConnectionId::from(packet.dcid()); - let scid = ConnectionId::from(packet.scid()); - let token = packet.token().to_vec(); - self.handle_initial(dcid, scid, token, dgram, now) + self.handle_initial(hdr, dgram, now) } /// Iterate through the pending connections looking for any that might want @@ -495,7 +517,7 @@ struct ServerConnectionIdManager { } impl ConnectionIdDecoder for ServerConnectionIdManager { - fn decode_cid<'a>(&self, dec: &mut Decoder<'a>) -> Option> { + fn decode_cid(&self, dec: &mut Decoder) -> Option { self.cid_manager.borrow_mut().decode_cid(dec) } } diff --git a/third_party/rust/neqo-transport/src/stats.rs b/third_party/rust/neqo-transport/src/stats.rs index 82ce131def51..92bde4260fd9 100644 --- a/third_party/rust/neqo-transport/src/stats.rs +++ b/third_party/rust/neqo-transport/src/stats.rs @@ -10,11 +10,9 @@ /// Connection statistics pub struct Stats { /// Total packets received - pub packets_rx: usize, + pub packets_rx: u64, /// Total packets sent - pub packets_tx: usize, + pub packets_tx: u64, /// Duplicate packets received - pub dups_rx: usize, - /// Dropped datagrams, or parts thereof - pub dropped_rx: usize, + pub dups_rx: u64, } diff --git a/third_party/rust/neqo-transport/src/tparams.rs b/third_party/rust/neqo-transport/src/tparams.rs index 291fbde8db43..d5ddcfca3b06 100644 --- a/third_party/rust/neqo-transport/src/tparams.rs +++ b/third_party/rust/neqo-transport/src/tparams.rs @@ -59,15 +59,15 @@ impl TransportParameter { fn encode(&self, enc: &mut Encoder, tipe: u16) { enc.encode_uint(2, tipe); match self { - Self::Bytes(a) => { + TransportParameter::Bytes(a) => { enc.encode_vec(2, a); } - Self::Integer(a) => { + TransportParameter::Integer(a) => { enc.encode_vec_with(2, |enc_inner| { enc_inner.encode_varint(*a); }); } - Self::Empty => { + TransportParameter::Empty => { enc.encode_uint(2, 0_u64); } }; @@ -85,12 +85,12 @@ impl TransportParameter { qtrace!("TP {:x} length {:x}", tipe, content.len()); let mut d = Decoder::from(content); let tp = match tipe { - ORIGINAL_CONNECTION_ID => Self::Bytes(d.decode_remainder().to_vec()), // TODO(mt) unnecessary copy + ORIGINAL_CONNECTION_ID => TransportParameter::Bytes(d.decode_remainder().to_vec()), // TODO(mt) unnecessary copy STATELESS_RESET_TOKEN => { if d.remaining() != 16 { return Err(Error::TransportParameterError); } - Self::Bytes(d.decode_remainder().to_vec()) // TODO(mt) unnecessary copy + TransportParameter::Bytes(d.decode_remainder().to_vec()) // TODO(mt) unnecessary copy } IDLE_TIMEOUT | INITIAL_MAX_DATA @@ -100,21 +100,21 @@ impl TransportParameter { | INITIAL_MAX_STREAMS_BIDI | INITIAL_MAX_STREAMS_UNI | MAX_ACK_DELAY => match d.decode_varint() { - Some(v) => Self::Integer(v), + Some(v) => TransportParameter::Integer(v), None => return Err(Error::TransportParameterError), }, MAX_PACKET_SIZE => match d.decode_varint() { - Some(v) if v >= 1200 => Self::Integer(v), + Some(v) if v >= 1200 => TransportParameter::Integer(v), _ => return Err(Error::TransportParameterError), }, ACK_DELAY_EXPONENT => match d.decode_varint() { - Some(v) if v <= 20 => Self::Integer(v), + Some(v) if v <= 20 => TransportParameter::Integer(v), _ => return Err(Error::TransportParameterError), }, - DISABLE_MIGRATION => Self::Empty, + DISABLE_MIGRATION => TransportParameter::Empty, // Skip. _ => return Ok(None), }; @@ -315,7 +315,7 @@ impl ExtensionHandler for TransportParametersHandler { } fn handle(&mut self, msg: HandshakeMessage, d: &[u8]) -> ExtensionHandlerResult { - qtrace!( + qdebug!( "Handling transport parameters, msg={:?} value={}", msg, hex(d), diff --git a/third_party/rust/neqo-transport/src/tracking.rs b/third_party/rust/neqo-transport/src/tracking.rs index a93935895c93..a2779e7b3fea 100644 --- a/third_party/rust/neqo-transport/src/tracking.rs +++ b/third_party/rust/neqo-transport/src/tracking.rs @@ -13,21 +13,19 @@ use std::ops::{Index, IndexMut}; use std::time::{Duration, Instant}; use neqo_common::{qdebug, qinfo, qtrace, qwarn}; -use neqo_crypto::{Epoch, TLS_EPOCH_APPLICATION_DATA, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL}; +use neqo_crypto::constants::Epoch; use crate::frame::{AckRange, Frame}; -use crate::packet::{PacketNumber, PacketType}; use crate::recovery::RecoveryToken; // TODO(mt) look at enabling EnumMap for this: https://stackoverflow.com/a/44905797/1375574 -#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Ord, Eq)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum PNSpace { - Initial, - Handshake, - ApplicationData, + Initial = 0, + Handshake = 1, + ApplicationData = 2, } -#[allow(clippy::use_self)] // https://github.com/rust-lang/rust-clippy/issues/3410 impl PNSpace { pub fn iter() -> impl Iterator { const SPACES: &[PNSpace] = &[ @@ -42,80 +40,17 @@ impl PNSpace { impl From for PNSpace { fn from(epoch: Epoch) -> Self { match epoch { - TLS_EPOCH_INITIAL => Self::Initial, - TLS_EPOCH_HANDSHAKE => Self::Handshake, - _ => Self::ApplicationData, + 0 => PNSpace::Initial, + 2 => PNSpace::Handshake, + _ => PNSpace::ApplicationData, } } } -impl From for PNSpace { - fn from(pt: PacketType) -> Self { - match pt { - PacketType::Initial => Self::Initial, - PacketType::Handshake => Self::Handshake, - PacketType::ZeroRtt | PacketType::Short => Self::ApplicationData, - _ => panic!("Attempted to get space from wrong packet type"), - } - } -} - -#[derive(Debug, Clone)] -pub struct SentPacket { - pub ack_eliciting: bool, - pub time_sent: Instant, - pub tokens: Vec, - - pub time_declared_lost: Option, - - pub in_flight: bool, - pub size: usize, -} - -impl SentPacket { - pub fn new( - time_sent: Instant, - ack_eliciting: bool, - tokens: Vec, - size: usize, - in_flight: bool, - ) -> Self { - Self { - time_sent, - ack_eliciting, - tokens, - time_declared_lost: None, - size, - in_flight, - } - } -} - -impl Into for PNSpace { - fn into(self) -> Epoch { - match self { - Self::Initial => TLS_EPOCH_INITIAL, - Self::Handshake => TLS_EPOCH_HANDSHAKE, - // Our epoch progresses forward, but the TLS epoch is fixed to 3. - Self::ApplicationData => TLS_EPOCH_APPLICATION_DATA, - } - } -} - -impl std::fmt::Display for PNSpace { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.write_str(match self { - Self::Initial => "in", - Self::Handshake => "hs", - Self::ApplicationData => "ap", - }) - } -} - #[derive(Clone, Debug, Default)] pub struct PacketRange { - largest: PacketNumber, - smallest: PacketNumber, + largest: u64, + smallest: u64, ack_needed: bool, } @@ -209,7 +144,7 @@ pub struct RecvdPackets { space: PNSpace, ranges: VecDeque, /// The packet number of the lowest number packet that we are tracking. - min_tracked: PacketNumber, + min_tracked: u64, /// The time we got the largest acknowledged. largest_pn_time: Option, // The time that we should be sending an ACK. @@ -267,7 +202,7 @@ impl RecvdPackets { /// Add the packet to the tracked set. pub fn set_received(&mut self, now: Instant, pn: u64, ack_eliciting: bool) { let next_in_order_pn = self.ranges.front().map_or(0, |pr| pr.largest + 1); - qdebug!([self], "next in order pn: {}", next_in_order_pn); + qdebug!("next in order pn: {}", next_in_order_pn); let i = self.add(pn); // The new addition was the largest, so update the time we use for calculating ACK delay. @@ -302,7 +237,7 @@ impl RecvdPackets { } /// Check if the packet is a duplicate. - pub fn is_duplicate(&self, pn: PacketNumber) -> bool { + pub fn is_duplicate(&self, pn: u64) -> bool { if pn < self.min_tracked { return true; } @@ -333,7 +268,7 @@ impl RecvdPackets { impl ::std::fmt::Display for RecvdPackets { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { - write!(f, "Recvd-{}", self.space) + write!(f, "Recvd{:?}", self.space) } } @@ -368,9 +303,9 @@ impl AckTracker { pub(crate) fn get_frame( &mut self, now: Instant, - pn_space: PNSpace, + epoch: Epoch, ) -> Option<(Frame, Option)> { - let space = &mut self[pn_space]; + let space = &mut self[PNSpace::from(epoch)]; // Check that we aren't delaying ACKs. if !space.ack_now(now) { @@ -421,7 +356,7 @@ impl AckTracker { Some(( ack, Some(RecoveryToken::Ack(AckToken { - space: pn_space, + space: PNSpace::from(epoch), ranges, })), )) diff --git a/third_party/rust/neqo-transport/tests/conn_vectors.rs b/third_party/rust/neqo-transport/tests/conn_vectors.rs index 4277d3a24c60..3a01fd623e82 100644 --- a/third_party/rust/neqo-transport/tests/conn_vectors.rs +++ b/third_party/rust/neqo-transport/tests/conn_vectors.rs @@ -6,13 +6,11 @@ // Tests with the test vectors from the spec. #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::pedantic)] - use neqo_common::{Datagram, Encoder}; use neqo_transport::State; use test_fixture::*; -const INITIAL_PACKET: &str = "c0ff000019088394c8f03e5157080000\ +const INITIAL_PACKET: &str = "c0ff000018088394c8f03e5157080000\ 449e3b343aa8535064a4268a0d9d7b1c\ 9d250ae355162276e9b1e3011ef6bbc0\ ab48ad5bcc2681e953857ca62becd752\ @@ -86,7 +84,7 @@ const INITIAL_PACKET: &str = "c0ff000019088394c8f03e5157080000\ d2bee680d8f41a597c262648bb18bcfc\ 13c8b3d97b1a77b2ac3af745d61a34cc\ 4709865bac824a94bb19058015e4e42d\ - aebe13f98ec51170a4aad0a8324bb768"; + 0488c1b9a230f7c894193cbb54ae795e"; #[test] fn process_client_initial() { diff --git a/third_party/rust/neqo-transport/tests/connection.rs b/third_party/rust/neqo-transport/tests/connection.rs index 61d60689d6a6..d2cd48900ca4 100644 --- a/third_party/rust/neqo-transport/tests/connection.rs +++ b/third_party/rust/neqo-transport/tests/connection.rs @@ -5,9 +5,9 @@ // except according to those terms. #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::use_self)] use neqo_common::Datagram; +use neqo_transport::State; use test_fixture::{self, default_client, default_server, now}; #[test] @@ -38,8 +38,8 @@ fn truncate_long_packet() { let dgram = client.process(None, now()).dgram(); assert!(dgram.is_some()); - assert!(client.state().connected()); + assert_eq!(*client.state(), State::Connected); let dgram = server.process(dgram, now()).dgram(); assert!(dgram.is_some()); - assert!(server.state().connected()); + assert_eq!(*server.state(), State::Connected); } diff --git a/third_party/rust/neqo-transport/tests/server.rs b/third_party/rust/neqo-transport/tests/server.rs index d51759d94830..e2c15024eb40 100644 --- a/third_party/rust/neqo-transport/tests/server.rs +++ b/third_party/rust/neqo-transport/tests/server.rs @@ -5,7 +5,6 @@ // except according to those terms. #![cfg_attr(feature = "deny-warnings", deny(warnings))] -#![warn(clippy::pedantic)] use neqo_common::{hex, matches, qdebug, qtrace, Datagram, Decoder, Encoder}; use neqo_crypto::{ @@ -43,7 +42,7 @@ fn default_server() -> Server { fn connected_server(server: &mut Server) -> ActiveConnectionRef { let server_connections = server.active_connections(); assert_eq!(server_connections.len(), 1); - assert_eq!(*server_connections[0].borrow().state(), State::Confirmed); + assert_eq!(*server_connections[0].borrow().state(), State::Connected); server_connections[0].clone() } @@ -68,13 +67,7 @@ fn connect(client: &mut Connection, server: &mut Server) -> ActiveConnectionRef assert!(dgram.is_some()); assert_eq!(*client.state(), State::Connected); let dgram = server.process(dgram, now()).dgram(); - assert!(dgram.is_some()); // ACK + HANDSHAKE_DONE + NST - - // Have the client process the HANDSHAKE_DONE. - let dgram = client.process(dgram, now()).dgram(); - assert!(dgram.is_none()); - assert_eq!(*client.state(), State::Confirmed); - + assert!(dgram.is_some()); // ACK + NST connected_server(server) } @@ -232,7 +225,7 @@ fn retry_after_initial() { } #[test] -fn retry_bad_integrity() { +fn retry_bad_odcid() { let mut server = default_server(); server.set_retry_required(true); let mut client = default_client(); @@ -245,9 +238,17 @@ fn retry_bad_integrity() { let retry = &dgram.as_ref().unwrap(); assertions::assert_retry(retry); - let mut tweaked = retry.to_vec(); - tweaked[retry.len() - 1] ^= 0x45; // damage the auth tag - let tweaked_packet = Datagram::new(retry.source(), retry.destination(), tweaked); + let mut dec = Decoder::new(retry); // Start after version. + dec.skip(5); + dec.skip_vec(1); // DCID + dec.skip_vec(1); // SCID + let odcid_len = dec.decode_byte().unwrap(); + assert_ne!(odcid_len, 0); + let odcid_offset = retry.len() - dec.remaining(); + assert!(odcid_offset < retry.len()); + let mut tweaked_retry = retry[..].to_vec(); + tweaked_retry[odcid_offset] ^= 0x45; // damage the ODCID + let tweaked_packet = Datagram::new(retry.source(), retry.destination(), tweaked_retry); // The client should ignore this packet. let dgram = client.process(Some(tweaked_packet), now()).dgram(); @@ -318,7 +319,6 @@ fn client_initial_aead_and_hp(dcid: &[u8]) -> (Aead, HpKey) { // at least 8 bytes long. Otherwise, the second Initial won't have a // long enough connection ID. #[test] -#[allow(clippy::shadow_unrelated)] fn mitm_retry() { let mut client = default_client(); let mut retry_server = default_server(); @@ -493,12 +493,8 @@ fn closed() { let mut client = default_client(); connect(&mut client, &mut server); - // The server will have sent a few things, so it will be on PTO. let res = server.process(None, now()); - assert!(res.callback() > Duration::new(0, 0)); - // The client will be on the delayed ACK timer. - let res = client.process(None, now()); - assert!(res.callback() > Duration::new(0, 0)); + assert_eq!(res, Output::Callback(Duration::from_secs(60))); qtrace!("60s later"); let res = server.process(None, now() + Duration::from_secs(60));