diff --git a/testing/web-platform/tests/tools/quic/__init__.py b/testing/web-platform/tests/tools/quic/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/testing/web-platform/tests/tools/quic/commands.json b/testing/web-platform/tests/tools/quic/commands.json new file mode 100644 index 000000000000..044322eed3ba --- /dev/null +++ b/testing/web-platform/tests/tools/quic/commands.json @@ -0,0 +1,13 @@ +{ + "servequic": { + "path": "serve.py", + "script": "run", + "parser": "get_parser", + "py3only": true, + "help": "Start the QUIC server", + "virtualenv": true, + "requirements": [ + "requirements.txt" + ] + } +} diff --git a/testing/web-platform/tests/tools/quic/requirements.txt b/testing/web-platform/tests/tools/quic/requirements.txt new file mode 100644 index 000000000000..165260c78f7a --- /dev/null +++ b/testing/web-platform/tests/tools/quic/requirements.txt @@ -0,0 +1 @@ +aioquic==0.8.7 diff --git a/testing/web-platform/tests/tools/quic/serve.py b/testing/web-platform/tests/tools/quic/serve.py new file mode 100755 index 000000000000..9893f25e27bd --- /dev/null +++ b/testing/web-platform/tests/tools/quic/serve.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 + +import argparse +import sys + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true", default=False, + help="turn on verbose logging") + return parser + + +def run(venv, **kwargs): + # TODO(Hexcles): Replace this with actual implementation. 
+ print(sys.version) + assert sys.version_info.major == 3 + import aioquic + print('aioquic: ' + aioquic.__version__) + + +def main(): + kwargs = vars(get_parser().parse_args()) + return run(None, **kwargs) + + +if __name__ == '__main__': + main() diff --git a/testing/web-platform/tests/tools/third_party/aioquic/.appveyor.yml b/testing/web-platform/tests/tools/third_party/aioquic/.appveyor.yml new file mode 100644 index 000000000000..d0edee187ac3 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/.appveyor.yml @@ -0,0 +1,12 @@ +environment: + CIBW_SKIP: cp27-* cp33-* cp34-* + CIBW_TEST_COMMAND: python -m unittest discover -s {project}/tests +install: + - cmd: C:\Python36-x64\python.exe -m pip install cibuildwheel +build_script: + - cmd: C:\Python36-x64\python.exe -m cibuildwheel --output-dir wheelhouse + - ps: >- + if ($env:APPVEYOR_REPO_TAG -eq "true") { + Invoke-Expression "python -m pip install twine" + Invoke-Expression "python -m twine upload --skip-existing wheelhouse/*.whl" + } diff --git a/testing/web-platform/tests/tools/third_party/aioquic/.gitattributes b/testing/web-platform/tests/tools/third_party/aioquic/.gitattributes new file mode 100644 index 000000000000..0b5030c0d291 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/.gitattributes @@ -0,0 +1 @@ +*.bin binary diff --git a/testing/web-platform/tests/tools/third_party/aioquic/.github/workflows/tests.yml b/testing/web-platform/tests/tools/third_party/aioquic/.github/workflows/tests.yml new file mode 100644 index 000000000000..d19b21d74588 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/.github/workflows/tests.yml @@ -0,0 +1,128 @@ +name: tests + +on: [push, pull_request] + +jobs: + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: Install packages + run: pip install black flake8 isort mypy + - name: Run linters + run: | + flake8 examples 
src tests + isort -c -df -rc examples src tests + black --check --diff examples src tests + mypy examples src + + test: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python: [3.8, 3.7, 3.6] + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + - name: Disable firewall and configure compiler + if: matrix.os == 'macos-latest' + run: | + sudo /usr/libexec/ApplicationFirewall/socketfilterfw --setglobalstate off + echo "::set-env name=AIOQUIC_SKIP_TESTS::chacha20" + echo "::set-env name=CFLAGS::-I/usr/local/opt/openssl/include" + echo "::set-env name=LDFLAGS::-L/usr/local/opt/openssl/lib" + - name: Install OpenSSL + if: matrix.os == 'windows-latest' + run: | + choco install openssl --no-progress + echo "::set-env name=CL::/IC:\Progra~1\OpenSSL-Win64\include" + echo "::set-env name=LINK::/LIBPATH:C:\Progra~1\OpenSSL-Win64\lib" + - name: Run tests + run: | + pip install -U pip setuptools wheel + pip install coverage + pip install . 
+ coverage run -m unittest discover -v + coverage xml + shell: bash + - name: Upload coverage report + uses: codecov/codecov-action@v1 + if: matrix.python != 'pypy3' + with: + token: ${{ secrets.CODECOV_TOKEN }} + + package-source: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: Build source package + run: python setup.py sdist + - name: Upload source package + uses: actions/upload-artifact@v1 + with: + name: dist + path: dist/ + + package-wheel: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: Install nasm + if: matrix.os == 'windows-latest' + run: choco install -y nasm + - name: Install nmake + if: matrix.os == 'windows-latest' + run: | + & "C:\Program Files (x86)\Microsoft Visual Studio\Installer\vs_installer.exe" modify ` + --installPath "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise" ` + --add Microsoft.VisualStudio.Component.VC.Tools.x86.x64 --passive --norestart + shell: powershell + - name: Build wheels + env: + CIBW_BEFORE_BUILD: scripts/build-openssl /tmp/vendor + CIBW_BEFORE_BUILD_WINDOWS: scripts\build-openssl.bat C:\cibw\vendor + CIBW_ENVIRONMENT: AIOQUIC_SKIP_TESTS=ipv6,loss CFLAGS=-I/tmp/vendor/include LDFLAGS=-L/tmp/vendor/lib + CIBW_ENVIRONMENT_WINDOWS: AIOQUIC_SKIP_TESTS=ipv6,loss CL="/IC:\cibw\vendor\include" LINK="/LIBPATH:C:\cibw\vendor\lib" + CIBW_SKIP: cp27-* cp33-* cp34-* cp35-* pp27-* + CIBW_TEST_COMMAND: python -m unittest discover -t {project} -s {project}/tests + run: | + pip install cibuildwheel + cibuildwheel --output-dir dist + - name: Upload wheels + uses: actions/upload-artifact@v1 + with: + name: dist + path: dist/ + + publish: + runs-on: ubuntu-latest + needs: [lint, test, package-source, package-wheel] + steps: + - uses: actions/checkout@v2 + - uses: 
actions/download-artifact@v1 + with: + name: dist + path: dist/ + - name: Publish to PyPI + if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/') + uses: pypa/gh-action-pypi-publish@master + with: + user: __token__ + password: ${{ secrets.PYPI_TOKEN }} diff --git a/testing/web-platform/tests/tools/third_party/aioquic/.gitignore b/testing/web-platform/tests/tools/third_party/aioquic/.gitignore new file mode 100644 index 000000000000..17af00e727ef --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/.gitignore @@ -0,0 +1,10 @@ +*.egg-info +*.pyc +*.so +.coverage +.eggs +.mypy_cache +.vscode +/build +/dist +/docs/_build diff --git a/testing/web-platform/tests/tools/third_party/aioquic/LICENSE b/testing/web-platform/tests/tools/third_party/aioquic/LICENSE new file mode 100644 index 000000000000..049924a5aa08 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2019 Jeremy Lainé. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of aioquic nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/testing/web-platform/tests/tools/third_party/aioquic/MANIFEST.in b/testing/web-platform/tests/tools/third_party/aioquic/MANIFEST.in new file mode 100644 index 000000000000..55e0903af994 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/MANIFEST.in @@ -0,0 +1,4 @@ +include LICENSE +recursive-include docs *.py *.rst Makefile +recursive-include examples *.html *.py +recursive-include tests *.bin *.pem *.py diff --git a/testing/web-platform/tests/tools/third_party/aioquic/README.rst b/testing/web-platform/tests/tools/third_party/aioquic/README.rst new file mode 100644 index 000000000000..2322042c24dd --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/README.rst @@ -0,0 +1,164 @@ +aioquic +======= + +|rtd| |pypi-v| |pypi-pyversions| |pypi-l| |tests| |codecov| |black| + +.. |rtd| image:: https://readthedocs.org/projects/aioquic/badge/?version=latest + :target: https://aioquic.readthedocs.io/ + +.. |pypi-v| image:: https://img.shields.io/pypi/v/aioquic.svg + :target: https://pypi.python.org/pypi/aioquic + +.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/aioquic.svg + :target: https://pypi.python.org/pypi/aioquic + +.. |pypi-l| image:: https://img.shields.io/pypi/l/aioquic.svg + :target: https://pypi.python.org/pypi/aioquic + +.. |tests| image:: https://github.com/aiortc/aioquic/workflows/tests/badge.svg + :target: https://github.com/aiortc/aioquic/actions + +.. 
|codecov| image:: https://img.shields.io/codecov/c/github/aiortc/aioquic.svg + :target: https://codecov.io/gh/aiortc/aioquic + +.. |black| image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/python/black + +What is ``aioquic``? +-------------------- + +``aioquic`` is a library for the QUIC network protocol in Python. It features +a minimal TLS 1.3 implementation, a QUIC stack and an HTTP/3 stack. + +QUIC standardisation is not finalised yet, but ``aioquic`` closely tracks the +specification drafts and is regularly tested for interoperability against other +`QUIC implementations`_. + +To learn more about ``aioquic`` please `read the documentation`_. + +Why should I use ``aioquic``? +----------------------------- + +``aioquic`` has been designed to be embedded into Python client and server +libraries wishing to support QUIC and / or HTTP/3. The goal is to provide a +common codebase for Python libraries in the hope of avoiding duplicated effort. + +Both the QUIC and the HTTP/3 APIs follow the "bring your own I/O" pattern, +leaving actual I/O operations to the API user. This approach has a number of +advantages including making the code testable and allowing integration with +different concurrency models. + +Features +-------- + +- QUIC stack conforming with draft-27 +- HTTP/3 stack conforming with draft-27 +- minimal TLS 1.3 implementation +- IPv4 and IPv6 support +- connection migration and NAT rebinding +- logging TLS traffic secrets +- logging QUIC events in QLOG format +- HTTP/3 server push support + +Requirements +------------ + +``aioquic`` requires Python 3.6 or better, and the OpenSSL development headers. + +Linux +..... + +On Debian/Ubuntu run: + +.. code-block:: console + + $ sudo apt install libssl-dev python3-dev + +On Alpine Linux you will also need the following: + +.. code-block:: console + + $ sudo apt install bsd-compat-headers libffi-dev + +OS X +.... + +On OS X run: + +.. 
code-block:: console + + $ brew install openssl + +You will need to set some environment variables to link against OpenSSL: + +.. code-block:: console + + $ export CFLAGS=-I/usr/local/opt/openssl/include + $ export LDFLAGS=-L/usr/local/opt/openssl/lib + +Windows +....... + +On Windows the easiest way to install OpenSSL is to use `Chocolatey`_. + +.. code-block:: console + + > choco install openssl + +You will need to set some environment variables to link against OpenSSL: + +.. code-block:: console + + > $Env:CL = "/IC:\Progra~1\OpenSSL-Win64\include" + > $Env:LINK = "/LIBPATH:C:\Progra~1\OpenSSL-Win64\lib" + +Running the examples +-------------------- + +After checking out the code using git you can run: + +.. code-block:: console + + $ pip install -e . + $ pip install aiofiles asgiref httpbin starlette wsproto + +HTTP/3 server +............. + +You can run the example server, which handles both HTTP/0.9 and HTTP/3: + +.. code-block:: console + + $ python examples/http3_server.py --certificate tests/ssl_cert.pem --private-key tests/ssl_key.pem + +HTTP/3 client +............. + +You can run the example client to perform an HTTP/3 request: + +.. code-block:: console + + $ python examples/http3_client.py --ca-certs tests/pycacert.pem https://localhost:4433/ + +Alternatively you can perform an HTTP/0.9 request: + +.. code-block:: console + + $ python examples/http3_client.py --ca-certs tests/pycacert.pem --legacy-http https://localhost:4433/ + +You can also open a WebSocket over HTTP/3: + +.. code-block:: console + + $ python examples/http3_client.py --ca-certs tests/pycacert.pem wss://localhost:4433/ws + +License +------- + +``aioquic`` is released under the `BSD license`_. + +.. _read the documentation: https://aioquic.readthedocs.io/en/latest/ +.. _QUIC implementations: https://github.com/quicwg/base-drafts/wiki/Implementations +.. _cryptography: https://cryptography.io/ +.. _Chocolatey: https://chocolatey.org/ +.. 
_BSD license: https://aioquic.readthedocs.io/en/latest/license.html diff --git a/testing/web-platform/tests/tools/third_party/aioquic/docs/Makefile b/testing/web-platform/tests/tools/third_party/aioquic/docs/Makefile new file mode 100644 index 000000000000..3cb738b9617f --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SPHINXPROJ = aioquic +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/testing/web-platform/tests/tools/third_party/aioquic/docs/asyncio.rst b/testing/web-platform/tests/tools/third_party/aioquic/docs/asyncio.rst new file mode 100644 index 000000000000..2ec568090146 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/docs/asyncio.rst @@ -0,0 +1,32 @@ +asyncio API +=========== + +The asyncio API provides a high-level QUIC API built on top of :mod:`asyncio`, +Python's standard asynchronous I/O framework. + +``aioquic`` comes with a selection of examples, including: + +- an HTTP/3 client +- an HTTP/3 server + +The examples can be browsed on GitHub: + +https://github.com/aiortc/aioquic/tree/master/examples + +.. automodule:: aioquic.asyncio + +Client +------ + + .. autofunction:: connect + +Server +------ + + .. autofunction:: serve + +Common +------ + + .. 
autoclass:: QuicConnectionProtocol + :members: diff --git a/testing/web-platform/tests/tools/third_party/aioquic/docs/conf.py b/testing/web-platform/tests/tools/third_party/aioquic/docs/conf.py new file mode 100644 index 000000000000..f89062ac9753 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/docs/conf.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# aioquic documentation build configuration file, created by +# sphinx-quickstart on Thu Feb 8 17:22:14 2018. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import sys, os + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.abspath('..')) + + +class MockBuffer: + Buffer = None + BufferReadError = None + BufferWriteError = None + + +class MockCrypto: + AEAD = None + CryptoError = ValueError + HeaderProtection = None + + +class MockPylsqpack: + Decoder = None + Encoder = None + StreamBlocked = None + + +sys.modules.update({ + "aioquic._buffer": MockBuffer(), + "aioquic._crypto": MockCrypto(), + "pylsqpack": MockPylsqpack(), +}) + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. 
+extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'sphinx_autodoc_typehints', + 'sphinxcontrib.asyncio', +] +intersphinx_mapping = { + 'cryptography': ('https://cryptography.io/en/latest', None), + 'python': ('https://docs.python.org/3', None), +} + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = 'aioquic' +copyright = u'2019, Jeremy Lainé' +author = u'Jeremy Lainé' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = '' +# The full version, including alpha/beta/rc tags. +release = '' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. 
+# +html_theme = 'alabaster' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + 'description': 'A library for QUIC in Python.', + 'github_button': True, + 'github_user': 'aiortc', + 'github_repo': 'aioquic', +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# This is required for the alabaster theme +# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars +html_sidebars = { + '**': [ + 'about.html', + 'navigation.html', + 'relations.html', + 'searchbox.html', + ] +} + + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'aioquicdoc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'aioquic.tex', 'aioquic Documentation', + author, 'manual'), +] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). 
+man_pages = [ + (master_doc, 'aioquic', 'aioquic Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'aioquic', 'aioquic Documentation', + author, 'aioquic', 'One line description of project.', + 'Miscellaneous'), +] diff --git a/testing/web-platform/tests/tools/third_party/aioquic/docs/design.rst b/testing/web-platform/tests/tools/third_party/aioquic/docs/design.rst new file mode 100644 index 000000000000..7529e762f5a6 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/docs/design.rst @@ -0,0 +1,37 @@ +Design +====== + +Sans-IO APIs +............ + +Both the QUIC and the HTTP/3 APIs follow the `sans I/O`_ pattern, leaving +actual I/O operations to the API user. This approach has a number of +advantages including making the code testable and allowing integration with +different concurrency models. + +TLS and encryption +.................. + +TLS 1.3 ++++++++ + +``aioquic`` features a minimal TLS 1.3 implementation built upon the +`cryptography`_ library. This is because QUIC requires some APIs which are +currently unavailable in mainstream TLS implementations such as OpenSSL: + +- the ability to extract traffic secrets + +- the ability to operate directly on TLS messages, without using the TLS + record layer + +Header protection and payload encryption +++++++++++++++++++++++++++++++++++++++++ + +QUIC makes extensive use of cryptographic operations to protect QUIC packet +headers and encrypt packet payloads. These operations occur for every single +packet and are a determining factor for performance. For this reason, they +are implemented as a C extension linked to `OpenSSL`_. + +.. _sans I/O: https://sans-io.readthedocs.io/ +.. _cryptography: https://cryptography.io/ +.. 
_OpenSSL: https://www.openssl.org/ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/docs/h3.rst b/testing/web-platform/tests/tools/third_party/aioquic/docs/h3.rst new file mode 100644 index 000000000000..30e2201c43d4 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/docs/h3.rst @@ -0,0 +1,44 @@ +HTTP/3 API +========== + +The HTTP/3 API performs no I/O on its own, leaving this to the API user. +This allows you to integrate HTTP/3 in any Python application, regardless of +the concurrency model you are using. + +Connection +---------- + +.. automodule:: aioquic.h3.connection + + .. autoclass:: H3Connection + :members: + + +Events +------ + +.. automodule:: aioquic.h3.events + + .. autoclass:: H3Event + :members: + + .. autoclass:: DataReceived + :members: + + .. autoclass:: HeadersReceived + :members: + + .. autoclass:: PushPromiseReceived + :members: + + +Exceptions +---------- + +.. automodule:: aioquic.h3.exceptions + + .. autoclass:: H3Error + :members: + + .. autoclass:: NoAvailablePushIDError + :members: diff --git a/testing/web-platform/tests/tools/third_party/aioquic/docs/index.rst b/testing/web-platform/tests/tools/third_party/aioquic/docs/index.rst new file mode 100644 index 000000000000..2841561a2240 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/docs/index.rst @@ -0,0 +1,39 @@ +aioquic +======= + +|pypi-v| |pypi-pyversions| |pypi-l| |tests| |codecov| + +.. |pypi-v| image:: https://img.shields.io/pypi/v/aioquic.svg + :target: https://pypi.python.org/pypi/aioquic + +.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/aioquic.svg + :target: https://pypi.python.org/pypi/aioquic + +.. |pypi-l| image:: https://img.shields.io/pypi/l/aioquic.svg + :target: https://pypi.python.org/pypi/aioquic + +.. |tests| image:: https://github.com/aiortc/aioquic/workflows/tests/badge.svg + :target: https://github.com/aiortc/aioquic/actions + +.. 
|codecov| image:: https://img.shields.io/codecov/c/github/aiortc/aioquic.svg + :target: https://codecov.io/gh/aiortc/aioquic + +``aioquic`` is a library for the QUIC network protocol in Python. It features several +APIs: + +- a QUIC API following the "bring your own I/O" pattern, suitable for + embedding in any framework, + +- an HTTP/3 API which also follows the "bring your own I/O" pattern, + +- a QUIC convenience API built on top of :mod:`asyncio`, Python's standard asynchronous + I/O framework. + +.. toctree:: + :maxdepth: 2 + + design + quic + h3 + asyncio + license diff --git a/testing/web-platform/tests/tools/third_party/aioquic/docs/license.rst b/testing/web-platform/tests/tools/third_party/aioquic/docs/license.rst new file mode 100644 index 000000000000..842d3b07fc93 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/docs/license.rst @@ -0,0 +1,4 @@ +License +------- + +.. literalinclude:: ../LICENSE diff --git a/testing/web-platform/tests/tools/third_party/aioquic/docs/quic.rst b/testing/web-platform/tests/tools/third_party/aioquic/docs/quic.rst new file mode 100644 index 000000000000..60ebefb49d51 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/docs/quic.rst @@ -0,0 +1,51 @@ +QUIC API +======== + +The QUIC API performs no I/O on its own, leaving this to the API user. +This allows you to integrate QUIC in any Python application, regardless of +the concurrency model you are using. + +Connection +---------- + +.. automodule:: aioquic.quic.connection + + .. autoclass:: QuicConnection + :members: + + +Configuration +------------- + +.. automodule:: aioquic.quic.configuration + + .. autoclass:: QuicConfiguration + :members: + +.. automodule:: aioquic.quic.logger + + .. autoclass:: QuicLogger + :members: + +Events +------ + +.. automodule:: aioquic.quic.events + + .. autoclass:: QuicEvent + :members: + + .. autoclass:: ConnectionTerminated + :members: + + .. autoclass:: HandshakeCompleted + :members: + + .. 
autoclass:: PingAcknowledged + :members: + + .. autoclass:: StreamDataReceived + :members: + + .. autoclass:: StreamReset + :members: diff --git a/testing/web-platform/tests/tools/third_party/aioquic/examples/demo.py b/testing/web-platform/tests/tools/third_party/aioquic/examples/demo.py new file mode 100644 index 000000000000..fd4f9752f14c --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/examples/demo.py @@ -0,0 +1,109 @@ +# +# demo application for http3_server.py +# + +import datetime +import os +from urllib.parse import urlencode + +import httpbin +from asgiref.wsgi import WsgiToAsgi +from starlette.applications import Starlette +from starlette.responses import PlainTextResponse, Response +from starlette.staticfiles import StaticFiles +from starlette.templating import Jinja2Templates +from starlette.websockets import WebSocketDisconnect + +ROOT = os.path.dirname(__file__) +STATIC_ROOT = os.environ.get("STATIC_ROOT", os.path.join(ROOT, "htdocs")) +STATIC_URL = "/" +LOGS_PATH = os.path.join(STATIC_ROOT, "logs") +QVIS_URL = "https://qvis.edm.uhasselt.be/" + +templates = Jinja2Templates(directory=os.path.join(ROOT, "templates")) +app = Starlette() + + +@app.route("/") +async def homepage(request): + """ + Simple homepage. + """ + await request.send_push_promise("/style.css") + return templates.TemplateResponse("index.html", {"request": request}) + + +@app.route("/echo", methods=["POST"]) +async def echo(request): + """ + HTTP echo endpoint. + """ + content = await request.body() + media_type = request.headers.get("content-type") + return Response(content, media_type=media_type) + + +@app.route("/logs/?") +async def logs(request): + """ + Browsable list of QLOG files. 
+ """ + logs = [] + for name in os.listdir(LOGS_PATH): + if name.endswith(".qlog"): + s = os.stat(os.path.join(LOGS_PATH, name)) + file_url = "https://" + request.headers["host"] + "/logs/" + name + logs.append( + { + "date": datetime.datetime.utcfromtimestamp(s.st_mtime).strftime( + "%Y-%m-%d %H:%M:%S" + ), + "file_url": file_url, + "name": name[:-5], + "qvis_url": QVIS_URL + + "?" + + urlencode({"file": file_url}) + + "#/sequence", + "size": s.st_size, + } + ) + return templates.TemplateResponse( + "logs.html", + { + "logs": sorted(logs, key=lambda x: x["date"], reverse=True), + "request": request, + }, + ) + + +@app.route("/{size:int}") +def padding(request): + """ + Dynamically generated data, maximum 50MB. + """ + size = min(50000000, request.path_params["size"]) + return PlainTextResponse("Z" * size) + + +@app.websocket_route("/ws") +async def ws(websocket): + """ + WebSocket echo endpoint. + """ + if "chat" in websocket.scope["subprotocols"]: + subprotocol = "chat" + else: + subprotocol = None + await websocket.accept(subprotocol=subprotocol) + + try: + while True: + message = await websocket.receive_text() + await websocket.send_text(message) + except WebSocketDisconnect: + pass + + +app.mount("/httpbin", WsgiToAsgi(httpbin.app)) + +app.mount(STATIC_URL, StaticFiles(directory=STATIC_ROOT, html=True)) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/examples/htdocs/robots.txt b/testing/web-platform/tests/tools/third_party/aioquic/examples/htdocs/robots.txt new file mode 100644 index 000000000000..51aa0bdeaace --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/examples/htdocs/robots.txt @@ -0,0 +1,2 @@ +User-agent: * +Disallow: /logs diff --git a/testing/web-platform/tests/tools/third_party/aioquic/examples/htdocs/style.css b/testing/web-platform/tests/tools/third_party/aioquic/examples/htdocs/style.css new file mode 100644 index 000000000000..01523db8ad8d --- /dev/null +++ 
b/testing/web-platform/tests/tools/third_party/aioquic/examples/htdocs/style.css @@ -0,0 +1,10 @@ +body { + font-family: Arial, sans-serif; + font-size: 16px; + margin: 0 auto; + width: 40em; +} + +table.logs { + width: 100%; +} diff --git a/testing/web-platform/tests/tools/third_party/aioquic/examples/http3_client.py b/testing/web-platform/tests/tools/third_party/aioquic/examples/http3_client.py new file mode 100644 index 000000000000..8a888a36ea7e --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/examples/http3_client.py @@ -0,0 +1,455 @@ +import argparse +import asyncio +import json +import logging +import os +import pickle +import ssl +import time +from collections import deque +from typing import Callable, Deque, Dict, List, Optional, Union, cast +from urllib.parse import urlparse + +import wsproto +import wsproto.events + +import aioquic +from aioquic.asyncio.client import connect +from aioquic.asyncio.protocol import QuicConnectionProtocol +from aioquic.h0.connection import H0_ALPN, H0Connection +from aioquic.h3.connection import H3_ALPN, H3Connection +from aioquic.h3.events import ( + DataReceived, + H3Event, + HeadersReceived, + PushPromiseReceived, +) +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.events import QuicEvent +from aioquic.quic.logger import QuicLogger +from aioquic.tls import SessionTicket + +try: + import uvloop +except ImportError: + uvloop = None + +logger = logging.getLogger("client") + +HttpConnection = Union[H0Connection, H3Connection] + +USER_AGENT = "aioquic/" + aioquic.__version__ + + +class URL: + def __init__(self, url: str) -> None: + parsed = urlparse(url) + + self.authority = parsed.netloc + self.full_path = parsed.path + if parsed.query: + self.full_path += "?" 
+ parsed.query + self.scheme = parsed.scheme + + +class HttpRequest: + def __init__( + self, method: str, url: URL, content: bytes = b"", headers: Dict = {} + ) -> None: + self.content = content + self.headers = headers + self.method = method + self.url = url + + +class WebSocket: + def __init__( + self, http: HttpConnection, stream_id: int, transmit: Callable[[], None] + ) -> None: + self.http = http + self.queue: asyncio.Queue[str] = asyncio.Queue() + self.stream_id = stream_id + self.subprotocol: Optional[str] = None + self.transmit = transmit + self.websocket = wsproto.Connection(wsproto.ConnectionType.CLIENT) + + async def close(self, code=1000, reason="") -> None: + """ + Perform the closing handshake. + """ + data = self.websocket.send( + wsproto.events.CloseConnection(code=code, reason=reason) + ) + self.http.send_data(stream_id=self.stream_id, data=data, end_stream=True) + self.transmit() + + async def recv(self) -> str: + """ + Receive the next message. + """ + return await self.queue.get() + + async def send(self, message: str) -> None: + """ + Send a message. 
+ """ + assert isinstance(message, str) + + data = self.websocket.send(wsproto.events.TextMessage(data=message)) + self.http.send_data(stream_id=self.stream_id, data=data, end_stream=False) + self.transmit() + + def http_event_received(self, event: H3Event) -> None: + if isinstance(event, HeadersReceived): + for header, value in event.headers: + if header == b"sec-websocket-protocol": + self.subprotocol = value.decode() + elif isinstance(event, DataReceived): + self.websocket.receive_data(event.data) + + for ws_event in self.websocket.events(): + self.websocket_event_received(ws_event) + + def websocket_event_received(self, event: wsproto.events.Event) -> None: + if isinstance(event, wsproto.events.TextMessage): + self.queue.put_nowait(event.data) + + +class HttpClient(QuicConnectionProtocol): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.pushes: Dict[int, Deque[H3Event]] = {} + self._http: Optional[HttpConnection] = None + self._request_events: Dict[int, Deque[H3Event]] = {} + self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {} + self._websockets: Dict[int, WebSocket] = {} + + if self._quic.configuration.alpn_protocols[0].startswith("hq-"): + self._http = H0Connection(self._quic) + else: + self._http = H3Connection(self._quic) + + async def get(self, url: str, headers: Dict = {}) -> Deque[H3Event]: + """ + Perform a GET request. + """ + return await self._request( + HttpRequest(method="GET", url=URL(url), headers=headers) + ) + + async def post(self, url: str, data: bytes, headers: Dict = {}) -> Deque[H3Event]: + """ + Perform a POST request. + """ + return await self._request( + HttpRequest(method="POST", url=URL(url), content=data, headers=headers) + ) + + async def websocket(self, url: str, subprotocols: List[str] = []) -> WebSocket: + """ + Open a WebSocket. 
+ """ + request = HttpRequest(method="CONNECT", url=URL(url)) + stream_id = self._quic.get_next_available_stream_id() + websocket = WebSocket( + http=self._http, stream_id=stream_id, transmit=self.transmit + ) + + self._websockets[stream_id] = websocket + + headers = [ + (b":method", b"CONNECT"), + (b":scheme", b"https"), + (b":authority", request.url.authority.encode()), + (b":path", request.url.full_path.encode()), + (b":protocol", b"websocket"), + (b"user-agent", USER_AGENT.encode()), + (b"sec-websocket-version", b"13"), + ] + if subprotocols: + headers.append( + (b"sec-websocket-protocol", ", ".join(subprotocols).encode()) + ) + self._http.send_headers(stream_id=stream_id, headers=headers) + + self.transmit() + + return websocket + + def http_event_received(self, event: H3Event) -> None: + if isinstance(event, (HeadersReceived, DataReceived)): + stream_id = event.stream_id + if stream_id in self._request_events: + # http + self._request_events[event.stream_id].append(event) + if event.stream_ended: + request_waiter = self._request_waiter.pop(stream_id) + request_waiter.set_result(self._request_events.pop(stream_id)) + + elif stream_id in self._websockets: + # websocket + websocket = self._websockets[stream_id] + websocket.http_event_received(event) + + elif event.push_id in self.pushes: + # push + self.pushes[event.push_id].append(event) + + elif isinstance(event, PushPromiseReceived): + self.pushes[event.push_id] = deque() + self.pushes[event.push_id].append(event) + + def quic_event_received(self, event: QuicEvent) -> None: + #  pass event to the HTTP layer + if self._http is not None: + for http_event in self._http.handle_event(event): + self.http_event_received(http_event) + + async def _request(self, request: HttpRequest): + stream_id = self._quic.get_next_available_stream_id() + self._http.send_headers( + stream_id=stream_id, + headers=[ + (b":method", request.method.encode()), + (b":scheme", request.url.scheme.encode()), + (b":authority", 
request.url.authority.encode()), + (b":path", request.url.full_path.encode()), + (b"user-agent", USER_AGENT.encode()), + ] + + [(k.encode(), v.encode()) for (k, v) in request.headers.items()], + ) + self._http.send_data(stream_id=stream_id, data=request.content, end_stream=True) + + waiter = self._loop.create_future() + self._request_events[stream_id] = deque() + self._request_waiter[stream_id] = waiter + self.transmit() + + return await asyncio.shield(waiter) + + +async def perform_http_request( + client: HttpClient, url: str, data: str, include: bool, output_dir: Optional[str], +) -> None: + # perform request + start = time.time() + if data is not None: + http_events = await client.post( + url, + data=data.encode(), + headers={"content-type": "application/x-www-form-urlencoded"}, + ) + else: + http_events = await client.get(url) + elapsed = time.time() - start + + # print speed + octets = 0 + for http_event in http_events: + if isinstance(http_event, DataReceived): + octets += len(http_event.data) + logger.info( + "Received %d bytes in %.1f s (%.3f Mbps)" + % (octets, elapsed, octets * 8 / elapsed / 1000000) + ) + + # output response + if output_dir is not None: + output_path = os.path.join( + output_dir, os.path.basename(urlparse(url).path) or "index.html" + ) + with open(output_path, "wb") as output_file: + for http_event in http_events: + if isinstance(http_event, HeadersReceived) and include: + headers = b"" + for k, v in http_event.headers: + headers += k + b": " + v + b"\r\n" + if headers: + output_file.write(headers + b"\r\n") + elif isinstance(http_event, DataReceived): + output_file.write(http_event.data) + + +def save_session_ticket(ticket: SessionTicket) -> None: + """ + Callback which is invoked by the TLS engine when a new session ticket + is received. 
+ """ + logger.info("New session ticket received") + if args.session_ticket: + with open(args.session_ticket, "wb") as fp: + pickle.dump(ticket, fp) + + +async def run( + configuration: QuicConfiguration, + urls: List[str], + data: str, + include: bool, + output_dir: Optional[str], + local_port: int, +) -> None: + # parse URL + parsed = urlparse(urls[0]) + assert parsed.scheme in ( + "https", + "wss", + ), "Only https:// or wss:// URLs are supported." + if ":" in parsed.netloc: + host, port_str = parsed.netloc.split(":") + port = int(port_str) + else: + host = parsed.netloc + port = 443 + + async with connect( + host, + port, + configuration=configuration, + create_protocol=HttpClient, + session_ticket_handler=save_session_ticket, + local_port=local_port, + ) as client: + client = cast(HttpClient, client) + + if parsed.scheme == "wss": + ws = await client.websocket(urls[0], subprotocols=["chat", "superchat"]) + + # send some messages and receive reply + for i in range(2): + message = "Hello {}, WebSocket!".format(i) + print("> " + message) + await ws.send(message) + + message = await ws.recv() + print("< " + message) + + await ws.close() + else: + # perform request + coros = [ + perform_http_request( + client=client, + url=url, + data=data, + include=include, + output_dir=output_dir, + ) + for url in urls + ] + await asyncio.gather(*coros) + + +if __name__ == "__main__": + defaults = QuicConfiguration(is_client=True) + + parser = argparse.ArgumentParser(description="HTTP/3 client") + parser.add_argument( + "url", type=str, nargs="+", help="the URL to query (must be HTTPS)" + ) + parser.add_argument( + "--ca-certs", type=str, help="load CA certificates from the specified file" + ) + parser.add_argument( + "-d", "--data", type=str, help="send the specified data in a POST request" + ) + parser.add_argument( + "-i", + "--include", + action="store_true", + help="include the HTTP response headers in the output", + ) + parser.add_argument( + "--max-data", + type=int, + 
help="connection-wide flow control limit (default: %d)" % defaults.max_data, + ) + parser.add_argument( + "--max-stream-data", + type=int, + help="per-stream flow control limit (default: %d)" % defaults.max_stream_data, + ) + parser.add_argument( + "-k", + "--insecure", + action="store_true", + help="do not validate server certificate", + ) + parser.add_argument("--legacy-http", action="store_true", help="use HTTP/0.9") + parser.add_argument( + "--output-dir", type=str, help="write downloaded files to this directory", + ) + parser.add_argument( + "-q", "--quic-log", type=str, help="log QUIC events to a file in QLOG format" + ) + parser.add_argument( + "-l", + "--secrets-log", + type=str, + help="log secrets to a file, for use with Wireshark", + ) + parser.add_argument( + "-s", + "--session-ticket", + type=str, + help="read and write session ticket from the specified file", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="increase logging verbosity" + ) + parser.add_argument( + "--local-port", type=int, default=0, help="local port to bind for connections", + ) + + args = parser.parse_args() + + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.DEBUG if args.verbose else logging.INFO, + ) + + if args.output_dir is not None and not os.path.isdir(args.output_dir): + raise Exception("%s is not a directory" % args.output_dir) + + # prepare configuration + configuration = QuicConfiguration( + is_client=True, alpn_protocols=H0_ALPN if args.legacy_http else H3_ALPN + ) + if args.ca_certs: + configuration.load_verify_locations(args.ca_certs) + if args.insecure: + configuration.verify_mode = ssl.CERT_NONE + if args.max_data: + configuration.max_data = args.max_data + if args.max_stream_data: + configuration.max_stream_data = args.max_stream_data + if args.quic_log: + configuration.quic_logger = QuicLogger() + if args.secrets_log: + configuration.secrets_log_file = open(args.secrets_log, "a") + if 
args.session_ticket: + try: + with open(args.session_ticket, "rb") as fp: + configuration.session_ticket = pickle.load(fp) + except FileNotFoundError: + pass + + if uvloop is not None: + uvloop.install() + loop = asyncio.get_event_loop() + try: + loop.run_until_complete( + run( + configuration=configuration, + urls=args.url, + data=args.data, + include=args.include, + output_dir=args.output_dir, + local_port=args.local_port, + ) + ) + finally: + if configuration.quic_logger is not None: + with open(args.quic_log, "w") as logger_fp: + json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/examples/http3_server.py b/testing/web-platform/tests/tools/third_party/aioquic/examples/http3_server.py new file mode 100644 index 000000000000..615b04b03c15 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/examples/http3_server.py @@ -0,0 +1,484 @@ +import argparse +import asyncio +import importlib +import json +import logging +import os +import time +from collections import deque +from email.utils import formatdate +from typing import Callable, Deque, Dict, List, Optional, Union, cast + +import wsproto +import wsproto.events + +import aioquic +from aioquic.asyncio import QuicConnectionProtocol, serve +from aioquic.h0.connection import H0_ALPN, H0Connection +from aioquic.h3.connection import H3_ALPN, H3Connection +from aioquic.h3.events import DataReceived, H3Event, HeadersReceived +from aioquic.h3.exceptions import NoAvailablePushIDError +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.events import DatagramFrameReceived, ProtocolNegotiated, QuicEvent +from aioquic.quic.logger import QuicLogger, QuicLoggerTrace +from aioquic.tls import SessionTicket + +try: + import uvloop +except ImportError: + uvloop = None + +AsgiApplication = Callable +HttpConnection = Union[H0Connection, H3Connection] + +SERVER_NAME = "aioquic/" + aioquic.__version__ + + +class 
HttpRequestHandler: + def __init__( + self, + *, + authority: bytes, + connection: HttpConnection, + protocol: QuicConnectionProtocol, + scope: Dict, + stream_ended: bool, + stream_id: int, + transmit: Callable[[], None], + ) -> None: + self.authority = authority + self.connection = connection + self.protocol = protocol + self.queue: asyncio.Queue[Dict] = asyncio.Queue() + self.scope = scope + self.stream_id = stream_id + self.transmit = transmit + + if stream_ended: + self.queue.put_nowait({"type": "http.request"}) + + def http_event_received(self, event: H3Event) -> None: + if isinstance(event, DataReceived): + self.queue.put_nowait( + { + "type": "http.request", + "body": event.data, + "more_body": not event.stream_ended, + } + ) + elif isinstance(event, HeadersReceived) and event.stream_ended: + self.queue.put_nowait( + {"type": "http.request", "body": b"", "more_body": False} + ) + + async def run_asgi(self, app: AsgiApplication) -> None: + await application(self.scope, self.receive, self.send) + + async def receive(self) -> Dict: + return await self.queue.get() + + async def send(self, message: Dict) -> None: + if message["type"] == "http.response.start": + self.connection.send_headers( + stream_id=self.stream_id, + headers=[ + (b":status", str(message["status"]).encode()), + (b"server", SERVER_NAME.encode()), + (b"date", formatdate(time.time(), usegmt=True).encode()), + ] + + [(k, v) for k, v in message["headers"]], + ) + elif message["type"] == "http.response.body": + self.connection.send_data( + stream_id=self.stream_id, + data=message.get("body", b""), + end_stream=not message.get("more_body", False), + ) + elif message["type"] == "http.response.push" and isinstance( + self.connection, H3Connection + ): + request_headers = [ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", self.authority), + (b":path", message["path"].encode()), + ] + [(k, v) for k, v in message["headers"]] + + # send push promise + try: + push_stream_id = 
self.connection.send_push_promise( + stream_id=self.stream_id, headers=request_headers + ) + except NoAvailablePushIDError: + return + + # fake request + cast(HttpServerProtocol, self.protocol).http_event_received( + HeadersReceived( + headers=request_headers, stream_ended=True, stream_id=push_stream_id + ) + ) + self.transmit() + + +class WebSocketHandler: + def __init__( + self, + *, + connection: HttpConnection, + scope: Dict, + stream_id: int, + transmit: Callable[[], None], + ) -> None: + self.closed = False + self.connection = connection + self.http_event_queue: Deque[DataReceived] = deque() + self.queue: asyncio.Queue[Dict] = asyncio.Queue() + self.scope = scope + self.stream_id = stream_id + self.transmit = transmit + self.websocket: Optional[wsproto.Connection] = None + + def http_event_received(self, event: H3Event) -> None: + if isinstance(event, DataReceived) and not self.closed: + if self.websocket is not None: + self.websocket.receive_data(event.data) + + for ws_event in self.websocket.events(): + self.websocket_event_received(ws_event) + else: + # delay event processing until we get `websocket.accept` + # from the ASGI application + self.http_event_queue.append(event) + + def websocket_event_received(self, event: wsproto.events.Event) -> None: + if isinstance(event, wsproto.events.TextMessage): + self.queue.put_nowait({"type": "websocket.receive", "text": event.data}) + elif isinstance(event, wsproto.events.Message): + self.queue.put_nowait({"type": "websocket.receive", "bytes": event.data}) + elif isinstance(event, wsproto.events.CloseConnection): + self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code}) + + async def run_asgi(self, app: AsgiApplication) -> None: + self.queue.put_nowait({"type": "websocket.connect"}) + + try: + await application(self.scope, self.receive, self.send) + finally: + if not self.closed: + await self.send({"type": "websocket.close", "code": 1000}) + + async def receive(self) -> Dict: + return await 
self.queue.get() + + async def send(self, message: Dict) -> None: + data = b"" + end_stream = False + if message["type"] == "websocket.accept": + subprotocol = message.get("subprotocol") + + self.websocket = wsproto.Connection(wsproto.ConnectionType.SERVER) + + headers = [ + (b":status", b"200"), + (b"server", SERVER_NAME.encode()), + (b"date", formatdate(time.time(), usegmt=True).encode()), + ] + if subprotocol is not None: + headers.append((b"sec-websocket-protocol", subprotocol.encode())) + self.connection.send_headers(stream_id=self.stream_id, headers=headers) + + # consume backlog + while self.http_event_queue: + self.http_event_received(self.http_event_queue.popleft()) + + elif message["type"] == "websocket.close": + if self.websocket is not None: + data = self.websocket.send( + wsproto.events.CloseConnection(code=message["code"]) + ) + else: + self.connection.send_headers( + stream_id=self.stream_id, headers=[(b":status", b"403")] + ) + end_stream = True + elif message["type"] == "websocket.send": + if message.get("text") is not None: + data = self.websocket.send( + wsproto.events.TextMessage(data=message["text"]) + ) + elif message.get("bytes") is not None: + data = self.websocket.send( + wsproto.events.Message(data=message["bytes"]) + ) + + if data: + self.connection.send_data( + stream_id=self.stream_id, data=data, end_stream=end_stream + ) + if end_stream: + self.closed = True + self.transmit() + + +Handler = Union[HttpRequestHandler, WebSocketHandler] + + +class HttpServerProtocol(QuicConnectionProtocol): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._handlers: Dict[int, Handler] = {} + self._http: Optional[HttpConnection] = None + + def http_event_received(self, event: H3Event) -> None: + if isinstance(event, HeadersReceived) and event.stream_id not in self._handlers: + authority = None + headers = [] + http_version = "0.9" if isinstance(self._http, H0Connection) else "3" + raw_path = b"" + method = "" + 
protocol = None + for header, value in event.headers: + if header == b":authority": + authority = value + headers.append((b"host", value)) + elif header == b":method": + method = value.decode() + elif header == b":path": + raw_path = value + elif header == b":protocol": + protocol = value.decode() + elif header and not header.startswith(b":"): + headers.append((header, value)) + + if b"?" in raw_path: + path_bytes, query_string = raw_path.split(b"?", maxsplit=1) + else: + path_bytes, query_string = raw_path, b"" + path = path_bytes.decode() + + # FIXME: add a public API to retrieve peer address + client_addr = self._http._quic._network_paths[0].addr + client = (client_addr[0], client_addr[1]) + + handler: Handler + scope: Dict + if method == "CONNECT" and protocol == "websocket": + subprotocols: List[str] = [] + for header, value in event.headers: + if header == b"sec-websocket-protocol": + subprotocols = [x.strip() for x in value.decode().split(",")] + scope = { + "client": client, + "headers": headers, + "http_version": http_version, + "method": method, + "path": path, + "query_string": query_string, + "raw_path": raw_path, + "root_path": "", + "scheme": "wss", + "subprotocols": subprotocols, + "type": "websocket", + } + handler = WebSocketHandler( + connection=self._http, + scope=scope, + stream_id=event.stream_id, + transmit=self.transmit, + ) + else: + extensions: Dict[str, Dict] = {} + if isinstance(self._http, H3Connection): + extensions["http.response.push"] = {} + scope = { + "client": client, + "extensions": extensions, + "headers": headers, + "http_version": http_version, + "method": method, + "path": path, + "query_string": query_string, + "raw_path": raw_path, + "root_path": "", + "scheme": "https", + "type": "http", + } + handler = HttpRequestHandler( + authority=authority, + connection=self._http, + protocol=self, + scope=scope, + stream_ended=event.stream_ended, + stream_id=event.stream_id, + transmit=self.transmit, + ) + 
self._handlers[event.stream_id] = handler + asyncio.ensure_future(handler.run_asgi(application)) + elif ( + isinstance(event, (DataReceived, HeadersReceived)) + and event.stream_id in self._handlers + ): + handler = self._handlers[event.stream_id] + handler.http_event_received(event) + + def quic_event_received(self, event: QuicEvent) -> None: + if isinstance(event, ProtocolNegotiated): + if event.alpn_protocol.startswith("h3-"): + self._http = H3Connection(self._quic) + elif event.alpn_protocol.startswith("hq-"): + self._http = H0Connection(self._quic) + elif isinstance(event, DatagramFrameReceived): + if event.data == b"quack": + self._quic.send_datagram_frame(b"quack-ack") + + #  pass event to the HTTP layer + if self._http is not None: + for http_event in self._http.handle_event(event): + self.http_event_received(http_event) + + +class SessionTicketStore: + """ + Simple in-memory store for session tickets. + """ + + def __init__(self) -> None: + self.tickets: Dict[bytes, SessionTicket] = {} + + def add(self, ticket: SessionTicket) -> None: + self.tickets[ticket.ticket] = ticket + + def pop(self, label: bytes) -> Optional[SessionTicket]: + return self.tickets.pop(label, None) + + +class QuicLoggerCustom(QuicLogger): + """ + Custom QUIC logger which writes one trace per file. 
+ """ + + def __init__(self, path: str) -> None: + if not os.path.isdir(path): + raise ValueError("QUIC log output directory '%s' does not exist" % path) + self.path = path + super().__init__() + + def end_trace(self, trace: QuicLoggerTrace) -> None: + trace_dict = trace.to_dict() + trace_path = os.path.join( + self.path, trace_dict["common_fields"]["ODCID"] + ".qlog" + ) + with open(trace_path, "w") as logger_fp: + json.dump({"qlog_version": "draft-01", "traces": [trace_dict]}, logger_fp) + self._traces.remove(trace) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="QUIC server") + parser.add_argument( + "app", + type=str, + nargs="?", + default="demo:app", + help="the ASGI application as :", + ) + parser.add_argument( + "-c", + "--certificate", + type=str, + required=True, + help="load the TLS certificate from the specified file", + ) + parser.add_argument( + "--host", + type=str, + default="::", + help="listen on the specified address (defaults to ::)", + ) + parser.add_argument( + "--port", + type=int, + default=4433, + help="listen on the specified port (defaults to 4433)", + ) + parser.add_argument( + "-k", + "--private-key", + type=str, + required=True, + help="load the TLS private key from the specified file", + ) + parser.add_argument( + "-l", + "--secrets-log", + type=str, + help="log secrets to a file, for use with Wireshark", + ) + parser.add_argument( + "-q", "--quic-log", type=str, help="log QUIC events to a file in QLOG format" + ) + parser.add_argument( + "-r", + "--stateless-retry", + action="store_true", + help="send a stateless retry for new connections", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="increase logging verbosity" + ) + args = parser.parse_args() + + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.DEBUG if args.verbose else logging.INFO, + ) + + # import ASGI application + module_str, attr_str = args.app.split(":", maxsplit=1) 
+ module = importlib.import_module(module_str) + application = getattr(module, attr_str) + + # create QUIC logger + if args.quic_log: + quic_logger = QuicLoggerCustom(args.quic_log) + else: + quic_logger = None + + # open SSL log file + if args.secrets_log: + secrets_log_file = open(args.secrets_log, "a") + else: + secrets_log_file = None + + configuration = QuicConfiguration( + alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"], + is_client=False, + max_datagram_frame_size=65536, + quic_logger=quic_logger, + secrets_log_file=secrets_log_file, + ) + + # load SSL certificate and key + configuration.load_cert_chain(args.certificate, args.private_key) + + ticket_store = SessionTicketStore() + + if uvloop is not None: + uvloop.install() + loop = asyncio.get_event_loop() + loop.run_until_complete( + serve( + args.host, + args.port, + configuration=configuration, + create_protocol=HttpServerProtocol, + session_ticket_fetcher=ticket_store.pop, + session_ticket_handler=ticket_store.add, + stateless_retry=args.stateless_retry, + ) + ) + try: + loop.run_forever() + except KeyboardInterrupt: + pass diff --git a/testing/web-platform/tests/tools/third_party/aioquic/examples/httpx_client.py b/testing/web-platform/tests/tools/third_party/aioquic/examples/httpx_client.py new file mode 100644 index 000000000000..001e384e354b --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/examples/httpx_client.py @@ -0,0 +1,213 @@ +import argparse +import asyncio +import json +import logging +import pickle +import sys +import time +from collections import deque +from typing import Deque, Dict, cast +from urllib.parse import urlparse + +from httpx import AsyncClient +from httpx.config import Timeout +from httpx.dispatch.base import AsyncDispatcher +from httpx.models import Request, Response + +from aioquic.asyncio.client import connect +from aioquic.asyncio.protocol import QuicConnectionProtocol +from aioquic.h3.connection import H3_ALPN, H3Connection +from aioquic.h3.events 
import DataReceived, H3Event, HeadersReceived +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.events import QuicEvent +from aioquic.quic.logger import QuicLogger + +logger = logging.getLogger("client") + + +class H3Dispatcher(QuicConnectionProtocol, AsyncDispatcher): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._http = H3Connection(self._quic) + self._request_events: Dict[int, Deque[H3Event]] = {} + self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {} + + async def send(self, request: Request, timeout: Timeout = None) -> Response: + stream_id = self._quic.get_next_available_stream_id() + + # prepare request + self._http.send_headers( + stream_id=stream_id, + headers=[ + (b":method", request.method.encode()), + (b":scheme", request.url.scheme.encode()), + (b":authority", str(request.url.authority).encode()), + (b":path", request.url.full_path.encode()), + ] + + [ + (k.encode(), v.encode()) + for (k, v) in request.headers.items() + if k not in ("connection", "host") + ], + ) + self._http.send_data(stream_id=stream_id, data=request.read(), end_stream=True) + + # transmit request + waiter = self._loop.create_future() + self._request_events[stream_id] = deque() + self._request_waiter[stream_id] = waiter + self.transmit() + + # process response + events: Deque[H3Event] = await asyncio.shield(waiter) + content = b"" + headers = [] + status_code = None + for event in events: + if isinstance(event, HeadersReceived): + for header, value in event.headers: + if header == b":status": + status_code = int(value.decode()) + elif header[0:1] != b":": + headers.append((header.decode(), value.decode())) + elif isinstance(event, DataReceived): + content += event.data + + return Response( + status_code=status_code, + http_version="HTTP/3", + headers=headers, + content=content, + request=request, + ) + + def http_event_received(self, event: H3Event): + if isinstance(event, (HeadersReceived, 
DataReceived)): + stream_id = event.stream_id + if stream_id in self._request_events: + self._request_events[event.stream_id].append(event) + if event.stream_ended: + request_waiter = self._request_waiter.pop(stream_id) + request_waiter.set_result(self._request_events.pop(stream_id)) + + def quic_event_received(self, event: QuicEvent): + #  pass event to the HTTP layer + if self._http is not None: + for http_event in self._http.handle_event(event): + self.http_event_received(http_event) + + +def save_session_ticket(ticket): + """ + Callback which is invoked by the TLS engine when a new session ticket + is received. + """ + logger.info("New session ticket received") + if args.session_ticket: + with open(args.session_ticket, "wb") as fp: + pickle.dump(ticket, fp) + + +async def run(configuration: QuicConfiguration, url: str, data: str) -> None: + # parse URL + parsed = urlparse(url) + assert parsed.scheme == "https", "Only https:// URLs are supported." + if ":" in parsed.netloc: + host, port_str = parsed.netloc.split(":") + port = int(port_str) + else: + host = parsed.netloc + port = 443 + + async with connect( + host, + port, + configuration=configuration, + create_protocol=H3Dispatcher, + session_ticket_handler=save_session_ticket, + ) as dispatch: + client = AsyncClient(dispatch=cast(AsyncDispatcher, dispatch)) + + # perform request + start = time.time() + if data is not None: + response = await client.post( + url, + data=data.encode(), + headers={"content-type": "application/x-www-form-urlencoded"}, + ) + else: + response = await client.get(url) + + elapsed = time.time() - start + + # print speed + octets = len(response.content) + logger.info( + "Received %d bytes in %.1f s (%.3f Mbps)" + % (octets, elapsed, octets * 8 / elapsed / 1000000) + ) + + # print response + for header, value in response.headers.items(): + sys.stderr.write(header + ": " + value + "\r\n") + sys.stderr.write("\r\n") + sys.stdout.buffer.write(response.content) + sys.stdout.buffer.flush() + + 
+if __name__ == "__main__": + parser = argparse.ArgumentParser(description="HTTP/3 client") + parser.add_argument("url", type=str, help="the URL to query (must be HTTPS)") + parser.add_argument( + "-d", "--data", type=str, help="send the specified data in a POST request" + ) + parser.add_argument( + "-q", "--quic-log", type=str, help="log QUIC events to a file in QLOG format" + ) + parser.add_argument( + "-l", + "--secrets-log", + type=str, + help="log secrets to a file, for use with Wireshark", + ) + parser.add_argument( + "-s", + "--session-ticket", + type=str, + help="read and write session ticket from the specified file", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="increase logging verbosity" + ) + + args = parser.parse_args() + + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.DEBUG if args.verbose else logging.INFO, + ) + + # prepare configuration + configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN) + if args.quic_log: + configuration.quic_logger = QuicLogger() + if args.secrets_log: + configuration.secrets_log_file = open(args.secrets_log, "a") + if args.session_ticket: + try: + with open(args.session_ticket, "rb") as fp: + configuration.session_ticket = pickle.load(fp) + except FileNotFoundError: + pass + + loop = asyncio.get_event_loop() + try: + loop.run_until_complete( + run(configuration=configuration, url=args.url, data=args.data) + ) + finally: + if configuration.quic_logger is not None: + with open(args.quic_log, "w") as logger_fp: + json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/examples/interop.py b/testing/web-platform/tests/tools/third_party/aioquic/examples/interop.py new file mode 100644 index 000000000000..d6ad20b66b4b --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/examples/interop.py @@ -0,0 +1,532 @@ +# +# !!! WARNING !!! 
+# +# This example uses some private APIs. +# + +import argparse +import asyncio +import json +import logging +import ssl +import time +from dataclasses import dataclass, field +from enum import Flag +from typing import Optional, cast + +import requests +import urllib3 +from http3_client import HttpClient + +from aioquic.asyncio import connect +from aioquic.h0.connection import H0_ALPN +from aioquic.h3.connection import H3_ALPN, H3Connection +from aioquic.h3.events import DataReceived, HeadersReceived, PushPromiseReceived +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.logger import QuicLogger + + +class Result(Flag): + V = 0x000001 + H = 0x000002 + D = 0x000004 + C = 0x000008 + R = 0x000010 + Z = 0x000020 + S = 0x000040 + Q = 0x000080 + + M = 0x000100 + B = 0x000200 + A = 0x000400 + U = 0x000800 + P = 0x001000 + E = 0x002000 + L = 0x004000 + T = 0x008000 + + three = 0x010000 + d = 0x020000 + p = 0x040000 + + def __str__(self): + flags = sorted( + map( + lambda x: getattr(Result, x), + filter(lambda x: not x.startswith("_"), dir(Result)), + ), + key=lambda x: x.value, + ) + result_str = "" + for flag in flags: + if self & flag: + result_str += flag.name + else: + result_str += "-" + return result_str + + +@dataclass +class Server: + name: str + host: str + port: int = 4433 + http3: bool = True + retry_port: Optional[int] = 4434 + path: str = "/" + push_path: Optional[str] = None + result: Result = field(default_factory=lambda: Result(0)) + session_resumption_port: Optional[int] = None + structured_logging: bool = False + throughput_file_suffix: str = "" + verify_mode: Optional[int] = None + + +SERVERS = [ + Server("akamaiquic", "ietf.akaquic.com", port=443, verify_mode=ssl.CERT_NONE), + Server( + "aioquic", "quic.aiortc.org", port=443, push_path="/", structured_logging=True + ), + Server("ats", "quic.ogre.com"), + Server("f5", "f5quic.com", retry_port=4433), + Server("haskell", "mew.org", retry_port=4433), + Server("gquic", 
"quic.rocks", retry_port=None), + Server("lsquic", "http3-test.litespeedtech.com", push_path="/200?push=/100"), + Server( + "msquic", + "quic.westus.cloudapp.azure.com", + port=4433, + session_resumption_port=4433, + structured_logging=True, + throughput_file_suffix=".txt", + verify_mode=ssl.CERT_NONE, + ), + Server( + "mvfst", "fb.mvfst.net", port=443, push_path="/push", structured_logging=True + ), + Server("ngtcp2", "nghttp2.org", push_path="/?push=/100"), + Server("ngx_quic", "cloudflare-quic.com", port=443, retry_port=443), + Server("pandora", "pandora.cm.in.tum.de", verify_mode=ssl.CERT_NONE), + Server("picoquic", "test.privateoctopus.com", structured_logging=True), + Server("quant", "quant.eggert.org", http3=False), + Server("quic-go", "quic.seemann.io", port=443, retry_port=443), + Server("quiche", "quic.tech", port=8443, retry_port=8444), + Server("quicly", "quic.examp1e.net"), + Server("quinn", "ralith.com"), +] + + +async def test_version_negotiation(server: Server, configuration: QuicConfiguration): + # force version negotiation + configuration.supported_versions.insert(0, 0x1A2A3A4A) + + async with connect( + server.host, server.port, configuration=configuration + ) as protocol: + await protocol.ping() + + # check log + for stamp, category, event, data in configuration.quic_logger.to_dict()[ + "traces" + ][0]["events"]: + if ( + category == "transport" + and event == "packet_received" + and data["packet_type"] == "version_negotiation" + ): + server.result |= Result.V + + +async def test_handshake_and_close(server: Server, configuration: QuicConfiguration): + async with connect( + server.host, server.port, configuration=configuration + ) as protocol: + await protocol.ping() + server.result |= Result.H + server.result |= Result.C + + +async def test_stateless_retry(server: Server, configuration: QuicConfiguration): + async with connect( + server.host, server.retry_port, configuration=configuration + ) as protocol: + await protocol.ping() + + # check log 
+ for stamp, category, event, data in configuration.quic_logger.to_dict()[ + "traces" + ][0]["events"]: + if ( + category == "transport" + and event == "packet_received" + and data["packet_type"] == "retry" + ): + server.result |= Result.S + + +async def test_quantum_readiness(server: Server, configuration: QuicConfiguration): + configuration.quantum_readiness_test = True + async with connect( + server.host, server.port, configuration=configuration + ) as protocol: + await protocol.ping() + server.result |= Result.Q + + +async def test_http_0(server: Server, configuration: QuicConfiguration): + if server.path is None: + return + + configuration.alpn_protocols = H0_ALPN + async with connect( + server.host, + server.port, + configuration=configuration, + create_protocol=HttpClient, + ) as protocol: + protocol = cast(HttpClient, protocol) + + # perform HTTP request + events = await protocol.get( + "https://{}:{}{}".format(server.host, server.port, server.path) + ) + if events and isinstance(events[0], HeadersReceived): + server.result |= Result.D + + +async def test_http_3(server: Server, configuration: QuicConfiguration): + if server.path is None: + return + + configuration.alpn_protocols = H3_ALPN + async with connect( + server.host, + server.port, + configuration=configuration, + create_protocol=HttpClient, + ) as protocol: + protocol = cast(HttpClient, protocol) + + # perform HTTP request + events = await protocol.get( + "https://{}:{}{}".format(server.host, server.port, server.path) + ) + if events and isinstance(events[0], HeadersReceived): + server.result |= Result.D + server.result |= Result.three + + # perform more HTTP requests to use QPACK dynamic tables + for i in range(2): + events = await protocol.get( + "https://{}:{}{}".format(server.host, server.port, server.path) + ) + if events and isinstance(events[0], HeadersReceived): + http = cast(H3Connection, protocol._http) + protocol._quic._logger.info( + "QPACK decoder bytes RX %d TX %d", + 
http._decoder_bytes_received, + http._decoder_bytes_sent, + ) + protocol._quic._logger.info( + "QPACK encoder bytes RX %d TX %d", + http._encoder_bytes_received, + http._encoder_bytes_sent, + ) + if ( + http._decoder_bytes_received + and http._decoder_bytes_sent + and http._encoder_bytes_received + and http._encoder_bytes_sent + ): + server.result |= Result.d + + # check push support + if server.push_path is not None: + protocol.pushes.clear() + await protocol.get( + "https://{}:{}{}".format(server.host, server.port, server.push_path) + ) + await asyncio.sleep(0.5) + for push_id, events in protocol.pushes.items(): + if ( + len(events) >= 3 + and isinstance(events[0], PushPromiseReceived) + and isinstance(events[1], HeadersReceived) + and isinstance(events[2], DataReceived) + ): + protocol._quic._logger.info( + "Push promise %d for %s received (status %s)", + push_id, + dict(events[0].headers)[b":path"].decode("ascii"), + int(dict(events[1].headers)[b":status"]), + ) + + server.result |= Result.p + + +async def test_session_resumption(server: Server, configuration: QuicConfiguration): + port = server.session_resumption_port or server.port + saved_ticket = None + + def session_ticket_handler(ticket): + nonlocal saved_ticket + saved_ticket = ticket + + # connect a first time, receive a ticket + async with connect( + server.host, + port, + configuration=configuration, + session_ticket_handler=session_ticket_handler, + ) as protocol: + await protocol.ping() + + # some servers don't send the ticket immediately + await asyncio.sleep(1) + + # connect a second time, with the ticket + if saved_ticket is not None: + configuration.session_ticket = saved_ticket + async with connect(server.host, port, configuration=configuration) as protocol: + await protocol.ping() + + # check session was resumed + if protocol._quic.tls.session_resumed: + server.result |= Result.R + + # check early data was accepted + if protocol._quic.tls.early_data_accepted: + server.result |= Result.Z + + 
+async def test_key_update(server: Server, configuration: QuicConfiguration): + async with connect( + server.host, server.port, configuration=configuration + ) as protocol: + # cause some traffic + await protocol.ping() + + # request key update + protocol.request_key_update() + + # cause more traffic + await protocol.ping() + + server.result |= Result.U + + +async def test_migration(server: Server, configuration: QuicConfiguration): + async with connect( + server.host, server.port, configuration=configuration + ) as protocol: + # cause some traffic + await protocol.ping() + + # change connection ID and replace transport + protocol.change_connection_id() + protocol._transport.close() + await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0)) + + # cause more traffic + await protocol.ping() + + # check log + dcids = set() + for stamp, category, event, data in configuration.quic_logger.to_dict()[ + "traces" + ][0]["events"]: + if ( + category == "transport" + and event == "packet_received" + and data["packet_type"] == "1RTT" + ): + dcids.add(data["header"]["dcid"]) + if len(dcids) == 2: + server.result |= Result.M + + +async def test_rebinding(server: Server, configuration: QuicConfiguration): + async with connect( + server.host, server.port, configuration=configuration + ) as protocol: + # cause some traffic + await protocol.ping() + + # replace transport + protocol._transport.close() + await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0)) + + # cause more traffic + await protocol.ping() + + server.result |= Result.B + + +async def test_spin_bit(server: Server, configuration: QuicConfiguration): + async with connect( + server.host, server.port, configuration=configuration + ) as protocol: + for i in range(5): + await protocol.ping() + + # check log + spin_bits = set() + for stamp, category, event, data in configuration.quic_logger.to_dict()[ + "traces" + ][0]["events"]: + if category == "connectivity" and event == 
async def test_throughput(server: Server, configuration: QuicConfiguration):
    """Compare download speed over QUIC against HTTP over TCP (Result.T)."""
    failures = 0

    for size in [5000000, 10000000]:
        print("Testing %d bytes download" % size)
        path = "/%d%s" % (size, server.throughput_file_suffix)

        # perform HTTP request over TCP
        start = time.time()
        response = requests.get("https://" + server.host + path, verify=False)
        tcp_octets = len(response.content)
        tcp_elapsed = time.time() - start
        assert tcp_octets == size, "HTTP/TCP response size mismatch"

        # perform HTTP request over QUIC
        if server.http3:
            configuration.alpn_protocols = H3_ALPN
        else:
            configuration.alpn_protocols = H0_ALPN
        start = time.time()
        async with connect(
            server.host,
            server.port,
            configuration=configuration,
            create_protocol=HttpClient,
        ) as protocol:
            protocol = cast(HttpClient, protocol)

            http_events = await protocol.get(
                "https://{}:{}{}".format(server.host, server.port, path)
            )
            quic_elapsed = time.time() - start
            quic_octets = 0
            for http_event in http_events:
                if isinstance(http_event, DataReceived):
                    quic_octets += len(http_event.data)
            assert quic_octets == size, "HTTP/QUIC response size mismatch"

        print(" - HTTP/TCP  completed in %.3f s" % tcp_elapsed)
        print(" - HTTP/QUIC completed in %.3f s" % quic_elapsed)

        # allow QUIC a 10% margin over TCP
        if quic_elapsed > 1.1 * tcp_elapsed:
            failures += 1
            print(" => FAIL")
        else:
            print(" => PASS")

    if failures == 0:
        server.result |= Result.T


def print_result(server: Server) -> None:
    """Print one server's result string, showing the ``three`` flag as ``3``."""
    result = str(server.result).replace("three", "3")
    result = result[0:8] + " " + result[8:16] + " " + result[16:]
    print("%s%s%s" % (server.name, " " * (20 - len(server.name)), result))


async def run(servers, tests, quic_log=False, secrets_log_file=None) -> None:
    """Run every test in *tests* against every server in *servers*.

    Each test gets a fresh QuicConfiguration; failures are printed and
    swallowed so one server/test cannot abort the whole matrix.
    """
    for server in servers:
        if server.structured_logging:
            server.result |= Result.L
        for test_name, test_func in tests:
            print("\n=== %s %s ===\n" % (server.name, test_name))
            configuration = QuicConfiguration(
                alpn_protocols=H3_ALPN + H0_ALPN,
                is_client=True,
                quic_logger=QuicLogger(),
                secrets_log_file=secrets_log_file,
                verify_mode=server.verify_mode,
            )
            # throughput tests transfer several megabytes; give them longer
            if test_name == "test_throughput":
                timeout = 60
            else:
                timeout = 10
            try:
                await asyncio.wait_for(
                    test_func(server, configuration), timeout=timeout
                )
            except Exception as exc:
                print(exc)

            if quic_log:
                with open("%s-%s.qlog" % (server.name, test_name), "w") as logger_fp:
                    json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4)

        print("")
        print_result(server)

    # print summary
    if len(servers) > 1:
        print("SUMMARY")
        for server in servers:
            print_result(server)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="QUIC interop client")
    parser.add_argument(
        "-q",
        "--quic-log",
        action="store_true",
        help="log QUIC events to a file in QLOG format",
    )
    parser.add_argument(
        "--server", type=str, help="only run against the specified server."
    )
    # FIX: "specifed" -> "specified" in the user-visible help text
    parser.add_argument("--test", type=str, help="only run the specified test.")
    parser.add_argument(
        "-l",
        "--secrets-log",
        type=str,
        help="log secrets to a file, for use with Wireshark",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="increase logging verbosity"
    )

    args = parser.parse_args()

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
        level=logging.DEBUG if args.verbose else logging.INFO,
    )

    # open SSL log file
    if args.secrets_log:
        secrets_log_file = open(args.secrets_log, "a")
    else:
        secrets_log_file = None

    # determine what to run
    servers = SERVERS
    tests = list(filter(lambda x: x[0].startswith("test_"), globals().items()))
    if args.server:
        servers = list(filter(lambda x: x.name == args.server, servers))
    if args.test:
        tests = list(filter(lambda x: x[0] == args.test, tests))

    # disable requests SSL warnings
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(
            run(
                servers=servers,
                tests=tests,
                quic_log=args.quic_log,
                secrets_log_file=secrets_log_file,
            )
        )
    finally:
        # FIX: close the secrets log so buffered key material is flushed;
        # the original left the file open at exit.
        if secrets_log_file is not None:
            secrets_log_file.close()
def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ack_waiter: Optional[asyncio.Future[None]] = None + + async def quack(self) -> None: + assert self._ack_waiter is None, "Only one quack at a time." + self._quic.send_datagram_frame(b"quack") + + waiter = self._loop.create_future() + self._ack_waiter = waiter + self.transmit() + + return await asyncio.shield(waiter) + + def quic_event_received(self, event: QuicEvent) -> None: + if self._ack_waiter is not None: + if isinstance(event, DatagramFrameReceived) and event.data == b"quack-ack": + waiter = self._ack_waiter + self._ack_waiter = None + waiter.set_result(None) + + +async def run(configuration: QuicConfiguration, host: str, port: int) -> None: + async with connect( + host, port, configuration=configuration, create_protocol=SiduckClient + ) as client: + client = cast(SiduckClient, client) + logger.info("sending quack") + await client.quack() + logger.info("received quack-ack") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="SiDUCK client") + parser.add_argument( + "host", type=str, help="The remote peer's host name or IP address" + ) + parser.add_argument("port", type=int, help="The remote peer's port number") + parser.add_argument( + "-k", + "--insecure", + action="store_true", + help="do not validate server certificate", + ) + parser.add_argument( + "-q", "--quic-log", type=str, help="log QUIC events to a file in QLOG format" + ) + parser.add_argument( + "-l", + "--secrets-log", + type=str, + help="log secrets to a file, for use with Wireshark", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="increase logging verbosity" + ) + + args = parser.parse_args() + + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.DEBUG if args.verbose else logging.INFO, + ) + + configuration = QuicConfiguration( + alpn_protocols=["siduck"], is_client=True, max_datagram_frame_size=65536 + ) + if 
args.insecure: + configuration.verify_mode = ssl.CERT_NONE + if args.quic_log: + configuration.quic_logger = QuicLogger() + if args.secrets_log: + configuration.secrets_log_file = open(args.secrets_log, "a") + + loop = asyncio.get_event_loop() + try: + loop.run_until_complete( + run(configuration=configuration, host=args.host, port=args.port) + ) + finally: + if configuration.quic_logger is not None: + with open(args.quic_log, "w") as logger_fp: + json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/examples/templates/index.html b/testing/web-platform/tests/tools/third_party/aioquic/examples/templates/index.html new file mode 100644 index 000000000000..d085487e8408 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/examples/templates/index.html @@ -0,0 +1,31 @@ + + + + + aioquic + + + +

Welcome to aioquic

+

+ This is a test page for aioquic, + a QUIC and HTTP/3 implementation written in Python. +

+{% if request.scope["http_version"] == "3" %} +

+ Congratulations, you loaded this page using HTTP/3! +

+{% endif %} +

Available endpoints

+ + + diff --git a/testing/web-platform/tests/tools/third_party/aioquic/examples/templates/logs.html b/testing/web-platform/tests/tools/third_party/aioquic/examples/templates/logs.html new file mode 100644 index 000000000000..7099f67dca22 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/examples/templates/logs.html @@ -0,0 +1,28 @@ + + + + + aioquic - logs + + + +

QLOG files

+ + + + + + +{% for log in logs %} + + + + + +{% endfor %} +
namedate (UTC)size
+ {{ log.name }} + [qvis] + {{ log.date }}{{ log.size }}
+ + diff --git a/testing/web-platform/tests/tools/third_party/aioquic/requirements/doc.txt b/testing/web-platform/tests/tools/third_party/aioquic/requirements/doc.txt new file mode 100644 index 000000000000..6c33a397874c --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/requirements/doc.txt @@ -0,0 +1,3 @@ +cryptography +sphinx_autodoc_typehints +sphinxcontrib-asyncio diff --git a/testing/web-platform/tests/tools/third_party/aioquic/scripts/build-openssl b/testing/web-platform/tests/tools/third_party/aioquic/scripts/build-openssl new file mode 100755 index 000000000000..94ded72c1724 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/scripts/build-openssl @@ -0,0 +1,29 @@ +#!/bin/sh + +set -e + +destdir=$1 +cachedir=$1.$PYTHON_ARCH + +for d in openssl $destdir; do + if [ -e $d ]; then + rm -rf $d + fi +done + +if [ ! -e $cachedir ]; then + # build openssl + mkdir openssl + curl -L https://www.openssl.org/source/openssl-1.1.1f.tar.gz | tar xz -C openssl --strip-components 1 + cd openssl + + ./config no-comp no-shared no-tests + make + + mkdir $cachedir + mkdir $cachedir/lib + cp -R include $cachedir + cp libcrypto.a libssl.a $cachedir/lib +fi + +cp -R $cachedir $destdir diff --git a/testing/web-platform/tests/tools/third_party/aioquic/scripts/build-openssl.bat b/testing/web-platform/tests/tools/third_party/aioquic/scripts/build-openssl.bat new file mode 100644 index 000000000000..610079ddf832 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/scripts/build-openssl.bat @@ -0,0 +1,39 @@ +set destdir=%1 +set cachedir=%1.%PYTHON_ARCH% + +for %%d in (openssl %destdir%) do ( + if exist %%d ( + rmdir /s /q %%d + ) +) + +if %PYTHON_ARCH% == 64 ( + set OPENSSL_CONFIG=VC-WIN64A + set VC_ARCH=x64 +) else ( + set OPENSSL_CONFIG=VC-WIN32 + set VC_ARCH=x86 +) + +call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" %VC_ARCH% +SET PATH=%PATH%;C:\Program Files\NASM + 
+if not exist %cachedir% ( + mkdir openssl + curl -L https://www.openssl.org/source/openssl-1.1.1f.tar.gz -o openssl.tar.gz + tar xzf openssl.tar.gz -C openssl --strip-components 1 + cd openssl + + perl Configure no-comp no-shared no-tests %OPENSSL_CONFIG% + nmake + + mkdir %cachedir% + mkdir %cachedir%\include + mkdir %cachedir%\lib + xcopy include %cachedir%\include\ /E + copy libcrypto.lib %cachedir%\lib\ + copy libssl.lib %cachedir%\lib\ +) + +mkdir %destdir% +xcopy %cachedir% %destdir% /E diff --git a/testing/web-platform/tests/tools/third_party/aioquic/setup.cfg b/testing/web-platform/tests/tools/third_party/aioquic/setup.cfg new file mode 100644 index 000000000000..89c025ad0f73 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/setup.cfg @@ -0,0 +1,21 @@ +[coverage:run] +source = aioquic + +[flake8] +ignore=E203,W503 +max-line-length=150 + +[isort] +default_section = THIRDPARTY +include_trailing_comma = True +known_first_party = aioquic +line_length = 88 +multi_line_output = 3 + +[mypy] +disallow_untyped_calls = True +disallow_untyped_decorators = True +ignore_missing_imports = True +strict_optional = False +warn_redundant_casts = True +warn_unused_ignores = True diff --git a/testing/web-platform/tests/tools/third_party/aioquic/setup.py b/testing/web-platform/tests/tools/third_party/aioquic/setup.py new file mode 100644 index 000000000000..a5c5ea3c443f --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/setup.py @@ -0,0 +1,69 @@ +import os.path +import sys + +import setuptools + +root_dir = os.path.abspath(os.path.dirname(__file__)) + +about = {} +about_file = os.path.join(root_dir, "src", "aioquic", "about.py") +with open(about_file, encoding="utf-8") as fp: + exec(fp.read(), about) + +readme_file = os.path.join(root_dir, "README.rst") +with open(readme_file, encoding="utf-8") as f: + long_description = f.read() + +if sys.platform == "win32": + extra_compile_args = [] + libraries = ["libcrypto", "advapi32", 
"crypt32", "gdi32", "user32", "ws2_32"] +else: + extra_compile_args = ["-std=c99"] + libraries = ["crypto"] + +setuptools.setup( + name=about["__title__"], + version=about["__version__"], + description=about["__summary__"], + long_description=long_description, + url=about["__uri__"], + author=about["__author__"], + author_email=about["__email__"], + license=about["__license__"], + include_package_data=True, + classifiers=[ + "Development Status :: 4 - Beta", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Internet :: WWW/HTTP", + ], + ext_modules=[ + setuptools.Extension( + "aioquic._buffer", + extra_compile_args=extra_compile_args, + sources=["src/aioquic/_buffer.c"], + ), + setuptools.Extension( + "aioquic._crypto", + extra_compile_args=extra_compile_args, + libraries=libraries, + sources=["src/aioquic/_crypto.c"], + ), + ], + package_dir={"": "src"}, + package_data={"aioquic": ["py.typed", "_buffer.pyi", "_crypto.pyi"]}, + packages=["aioquic", "aioquic.asyncio", "aioquic.h0", "aioquic.h3", "aioquic.quic"], + install_requires=[ + "certifi", + "cryptography >= 2.5", + 'dataclasses; python_version < "3.7"', + "pylsqpack >= 0.3.3, < 0.4.0", + ], +) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/__init__.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/__init__.py new file mode 100644 index 000000000000..04acebeef83e --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .about import __version__ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_buffer.c 
b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_buffer.c new file mode 100644 index 000000000000..eefc4db81759 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_buffer.c @@ -0,0 +1,440 @@ +#define PY_SSIZE_T_CLEAN + +#include +#include + +#define MODULE_NAME "aioquic._buffer" + +static PyObject *BufferReadError; +static PyObject *BufferWriteError; + +typedef struct { + PyObject_HEAD + uint8_t *base; + uint8_t *end; + uint8_t *pos; +} BufferObject; + +#define CHECK_READ_BOUNDS(self, len) \ + if (len < 0 || self->pos + len > self->end) { \ + PyErr_SetString(BufferReadError, "Read out of bounds"); \ + return NULL; \ + } + +#define CHECK_WRITE_BOUNDS(self, len) \ + if (self->pos + len > self->end) { \ + PyErr_SetString(BufferWriteError, "Write out of bounds"); \ + return NULL; \ + } + +static int +Buffer_init(BufferObject *self, PyObject *args, PyObject *kwargs) +{ + const char *kwlist[] = {"capacity", "data", NULL}; + int capacity = 0; + const unsigned char *data = NULL; + Py_ssize_t data_len = 0; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|iy#", (char**)kwlist, &capacity, &data, &data_len)) + return -1; + + if (data != NULL) { + self->base = malloc(data_len); + self->end = self->base + data_len; + memcpy(self->base, data, data_len); + } else { + self->base = malloc(capacity); + self->end = self->base + capacity; + } + self->pos = self->base; + return 0; +} + +static void +Buffer_dealloc(BufferObject *self) +{ + free(self->base); +} + +static PyObject * +Buffer_data_slice(BufferObject *self, PyObject *args) +{ + int start, stop; + if (!PyArg_ParseTuple(args, "ii", &start, &stop)) + return NULL; + + if (start < 0 || self->base + start > self->end || + stop < 0 || self->base + stop > self->end || + stop < start) { + PyErr_SetString(BufferReadError, "Read out of bounds"); + return NULL; + } + + return PyBytes_FromStringAndSize((const char*)(self->base + start), (stop - start)); +} + +static PyObject * 
+Buffer_eof(BufferObject *self, PyObject *args) +{ + if (self->pos == self->end) + Py_RETURN_TRUE; + Py_RETURN_FALSE; +} + +static PyObject * +Buffer_pull_bytes(BufferObject *self, PyObject *args) +{ + int len; + if (!PyArg_ParseTuple(args, "i", &len)) + return NULL; + + CHECK_READ_BOUNDS(self, len); + + PyObject *o = PyBytes_FromStringAndSize((const char*)self->pos, len); + self->pos += len; + return o; +} + +static PyObject * +Buffer_pull_uint8(BufferObject *self, PyObject *args) +{ + CHECK_READ_BOUNDS(self, 1) + + return PyLong_FromUnsignedLong( + (uint8_t)(*(self->pos++)) + ); +} + +static PyObject * +Buffer_pull_uint16(BufferObject *self, PyObject *args) +{ + CHECK_READ_BOUNDS(self, 2) + + uint16_t value = (uint16_t)(*(self->pos)) << 8 | + (uint16_t)(*(self->pos + 1)); + self->pos += 2; + return PyLong_FromUnsignedLong(value); +} + +static PyObject * +Buffer_pull_uint32(BufferObject *self, PyObject *args) +{ + CHECK_READ_BOUNDS(self, 4) + + uint32_t value = (uint32_t)(*(self->pos)) << 24 | + (uint32_t)(*(self->pos + 1)) << 16 | + (uint32_t)(*(self->pos + 2)) << 8 | + (uint32_t)(*(self->pos + 3)); + self->pos += 4; + return PyLong_FromUnsignedLong(value); +} + +static PyObject * +Buffer_pull_uint64(BufferObject *self, PyObject *args) +{ + CHECK_READ_BOUNDS(self, 8) + + uint64_t value = (uint64_t)(*(self->pos)) << 56 | + (uint64_t)(*(self->pos + 1)) << 48 | + (uint64_t)(*(self->pos + 2)) << 40 | + (uint64_t)(*(self->pos + 3)) << 32 | + (uint64_t)(*(self->pos + 4)) << 24 | + (uint64_t)(*(self->pos + 5)) << 16 | + (uint64_t)(*(self->pos + 6)) << 8 | + (uint64_t)(*(self->pos + 7)); + self->pos += 8; + return PyLong_FromUnsignedLongLong(value); +} + +static PyObject * +Buffer_pull_uint_var(BufferObject *self, PyObject *args) +{ + uint64_t value; + CHECK_READ_BOUNDS(self, 1) + switch (*(self->pos) >> 6) { + case 0: + value = *(self->pos++) & 0x3F; + break; + case 1: + CHECK_READ_BOUNDS(self, 2) + value = (uint16_t)(*(self->pos) & 0x3F) << 8 | + (uint16_t)(*(self->pos 
+ 1)); + self->pos += 2; + break; + case 2: + CHECK_READ_BOUNDS(self, 4) + value = (uint32_t)(*(self->pos) & 0x3F) << 24 | + (uint32_t)(*(self->pos + 1)) << 16 | + (uint32_t)(*(self->pos + 2)) << 8 | + (uint32_t)(*(self->pos + 3)); + self->pos += 4; + break; + default: + CHECK_READ_BOUNDS(self, 8) + value = (uint64_t)(*(self->pos) & 0x3F) << 56 | + (uint64_t)(*(self->pos + 1)) << 48 | + (uint64_t)(*(self->pos + 2)) << 40 | + (uint64_t)(*(self->pos + 3)) << 32 | + (uint64_t)(*(self->pos + 4)) << 24 | + (uint64_t)(*(self->pos + 5)) << 16 | + (uint64_t)(*(self->pos + 6)) << 8 | + (uint64_t)(*(self->pos + 7)); + self->pos += 8; + break; + } + return PyLong_FromUnsignedLongLong(value); +} + +static PyObject * +Buffer_push_bytes(BufferObject *self, PyObject *args) +{ + const unsigned char *data; + Py_ssize_t data_len; + if (!PyArg_ParseTuple(args, "y#", &data, &data_len)) + return NULL; + + CHECK_WRITE_BOUNDS(self, data_len) + + memcpy(self->pos, data, data_len); + self->pos += data_len; + Py_RETURN_NONE; +} + +static PyObject * +Buffer_push_uint8(BufferObject *self, PyObject *args) +{ + uint8_t value; + if (!PyArg_ParseTuple(args, "B", &value)) + return NULL; + + CHECK_WRITE_BOUNDS(self, 1) + + *(self->pos++) = value; + Py_RETURN_NONE; +} + +static PyObject * +Buffer_push_uint16(BufferObject *self, PyObject *args) +{ + uint16_t value; + if (!PyArg_ParseTuple(args, "H", &value)) + return NULL; + + CHECK_WRITE_BOUNDS(self, 2) + + *(self->pos++) = (value >> 8); + *(self->pos++) = value; + Py_RETURN_NONE; +} + +static PyObject * +Buffer_push_uint32(BufferObject *self, PyObject *args) +{ + uint32_t value; + if (!PyArg_ParseTuple(args, "I", &value)) + return NULL; + + CHECK_WRITE_BOUNDS(self, 4) + *(self->pos++) = (value >> 24); + *(self->pos++) = (value >> 16); + *(self->pos++) = (value >> 8); + *(self->pos++) = value; + Py_RETURN_NONE; +} + +static PyObject * +Buffer_push_uint64(BufferObject *self, PyObject *args) +{ + uint64_t value; + if (!PyArg_ParseTuple(args, "K", 
&value)) + return NULL; + + CHECK_WRITE_BOUNDS(self, 8) + *(self->pos++) = (value >> 56); + *(self->pos++) = (value >> 48); + *(self->pos++) = (value >> 40); + *(self->pos++) = (value >> 32); + *(self->pos++) = (value >> 24); + *(self->pos++) = (value >> 16); + *(self->pos++) = (value >> 8); + *(self->pos++) = value; + Py_RETURN_NONE; +} + +static PyObject * +Buffer_push_uint_var(BufferObject *self, PyObject *args) +{ + uint64_t value; + if (!PyArg_ParseTuple(args, "K", &value)) + return NULL; + + if (value <= 0x3F) { + CHECK_WRITE_BOUNDS(self, 1) + *(self->pos++) = value; + Py_RETURN_NONE; + } else if (value <= 0x3FFF) { + CHECK_WRITE_BOUNDS(self, 2) + *(self->pos++) = (value >> 8) | 0x40; + *(self->pos++) = value; + Py_RETURN_NONE; + } else if (value <= 0x3FFFFFFF) { + CHECK_WRITE_BOUNDS(self, 4) + *(self->pos++) = (value >> 24) | 0x80; + *(self->pos++) = (value >> 16); + *(self->pos++) = (value >> 8); + *(self->pos++) = value; + Py_RETURN_NONE; + } else if (value <= 0x3FFFFFFFFFFFFFFF) { + CHECK_WRITE_BOUNDS(self, 8) + *(self->pos++) = (value >> 56) | 0xC0; + *(self->pos++) = (value >> 48); + *(self->pos++) = (value >> 40); + *(self->pos++) = (value >> 32); + *(self->pos++) = (value >> 24); + *(self->pos++) = (value >> 16); + *(self->pos++) = (value >> 8); + *(self->pos++) = value; + Py_RETURN_NONE; + } else { + PyErr_SetString(PyExc_ValueError, "Integer is too big for a variable-length integer"); + return NULL; + } +} + +static PyObject * +Buffer_seek(BufferObject *self, PyObject *args) +{ + int pos; + if (!PyArg_ParseTuple(args, "i", &pos)) + return NULL; + + if (pos < 0 || self->base + pos > self->end) { + PyErr_SetString(BufferReadError, "Seek out of bounds"); + return NULL; + } + + self->pos = self->base + pos; + Py_RETURN_NONE; +} + +static PyObject * +Buffer_tell(BufferObject *self, PyObject *args) +{ + return PyLong_FromSsize_t(self->pos - self->base); +} + +static PyMethodDef Buffer_methods[] = { + {"data_slice", (PyCFunction)Buffer_data_slice, 
METH_VARARGS, ""}, + {"eof", (PyCFunction)Buffer_eof, METH_VARARGS, ""}, + {"pull_bytes", (PyCFunction)Buffer_pull_bytes, METH_VARARGS, "Pull bytes."}, + {"pull_uint8", (PyCFunction)Buffer_pull_uint8, METH_VARARGS, "Pull an 8-bit unsigned integer."}, + {"pull_uint16", (PyCFunction)Buffer_pull_uint16, METH_VARARGS, "Pull a 16-bit unsigned integer."}, + {"pull_uint32", (PyCFunction)Buffer_pull_uint32, METH_VARARGS, "Pull a 32-bit unsigned integer."}, + {"pull_uint64", (PyCFunction)Buffer_pull_uint64, METH_VARARGS, "Pull a 64-bit unsigned integer."}, + {"pull_uint_var", (PyCFunction)Buffer_pull_uint_var, METH_VARARGS, "Pull a QUIC variable-length unsigned integer."}, + {"push_bytes", (PyCFunction)Buffer_push_bytes, METH_VARARGS, "Push bytes."}, + {"push_uint8", (PyCFunction)Buffer_push_uint8, METH_VARARGS, "Push an 8-bit unsigned integer."}, + {"push_uint16", (PyCFunction)Buffer_push_uint16, METH_VARARGS, "Push a 16-bit unsigned integer."}, + {"push_uint32", (PyCFunction)Buffer_push_uint32, METH_VARARGS, "Push a 32-bit unsigned integer."}, + {"push_uint64", (PyCFunction)Buffer_push_uint64, METH_VARARGS, "Push a 64-bit unsigned integer."}, + {"push_uint_var", (PyCFunction)Buffer_push_uint_var, METH_VARARGS, "Push a QUIC variable-length unsigned integer."}, + {"seek", (PyCFunction)Buffer_seek, METH_VARARGS, ""}, + {"tell", (PyCFunction)Buffer_tell, METH_VARARGS, ""}, + {NULL} +}; + +static PyObject* +Buffer_capacity_getter(BufferObject* self, void *closure) { + return PyLong_FromSsize_t(self->end - self->base); +} + +static PyObject* +Buffer_data_getter(BufferObject* self, void *closure) { + return PyBytes_FromStringAndSize((const char*)self->base, self->pos - self->base); +} + +static PyGetSetDef Buffer_getset[] = { + {"capacity", (getter) Buffer_capacity_getter, NULL, "", NULL }, + {"data", (getter) Buffer_data_getter, NULL, "", NULL }, + {NULL} +}; + +static PyTypeObject BufferType = { + PyVarObject_HEAD_INIT(NULL, 0) + MODULE_NAME ".Buffer", /* tp_name */ + 
sizeof(BufferObject), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)Buffer_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "Buffer objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Buffer_methods, /* tp_methods */ + 0, /* tp_members */ + Buffer_getset, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)Buffer_init, /* tp_init */ + 0, /* tp_alloc */ +}; + + +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + MODULE_NAME, /* m_name */ + "A faster buffer.", /* m_doc */ + -1, /* m_size */ + NULL, /* m_methods */ + NULL, /* m_reload */ + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL, /* m_free */ +}; + + +PyMODINIT_FUNC +PyInit__buffer(void) +{ + PyObject* m; + + m = PyModule_Create(&moduledef); + if (m == NULL) + return NULL; + + BufferReadError = PyErr_NewException(MODULE_NAME ".BufferReadError", PyExc_ValueError, NULL); + Py_INCREF(BufferReadError); + PyModule_AddObject(m, "BufferReadError", BufferReadError); + + BufferWriteError = PyErr_NewException(MODULE_NAME ".BufferWriteError", PyExc_ValueError, NULL); + Py_INCREF(BufferWriteError); + PyModule_AddObject(m, "BufferWriteError", BufferWriteError); + + BufferType.tp_new = PyType_GenericNew; + if (PyType_Ready(&BufferType) < 0) + return NULL; + Py_INCREF(&BufferType); + PyModule_AddObject(m, "Buffer", (PyObject *)&BufferType); + + return m; +} diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_buffer.pyi 
b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_buffer.pyi new file mode 100644 index 000000000000..a6e43cca3a51 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_buffer.pyi @@ -0,0 +1,27 @@ +from typing import Optional + +class BufferReadError(ValueError): ... +class BufferWriteError(ValueError): ... + +class Buffer: + def __init__(self, capacity: Optional[int] = 0, data: Optional[bytes] = None): ... + @property + def capacity(self) -> int: ... + @property + def data(self) -> bytes: ... + def data_slice(self, start: int, end: int) -> bytes: ... + def eof(self) -> bool: ... + def seek(self, pos: int) -> None: ... + def tell(self) -> int: ... + def pull_bytes(self, length: int) -> bytes: ... + def pull_uint8(self) -> int: ... + def pull_uint16(self) -> int: ... + def pull_uint32(self) -> int: ... + def pull_uint64(self) -> int: ... + def pull_uint_var(self) -> int: ... + def push_bytes(self, value: bytes) -> None: ... + def push_uint8(self, value: int) -> None: ... + def push_uint16(self, value: int) -> None: ... + def push_uint32(self, v: int) -> None: ... + def push_uint64(self, v: int) -> None: ... + def push_uint_var(self, value: int) -> None: ... 
diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_crypto.c b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_crypto.c new file mode 100644 index 000000000000..5e1554d4b213 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_crypto.c @@ -0,0 +1,455 @@ +#define PY_SSIZE_T_CLEAN + +#include +#include +#include + +#define MODULE_NAME "aioquic._crypto" + +#define AEAD_KEY_LENGTH_MAX 32 +#define AEAD_NONCE_LENGTH 12 +#define AEAD_TAG_LENGTH 16 + +#define PACKET_LENGTH_MAX 1500 +#define PACKET_NUMBER_LENGTH_MAX 4 +#define SAMPLE_LENGTH 16 + +#define CHECK_RESULT(expr) \ + if (!(expr)) { \ + ERR_clear_error(); \ + PyErr_SetString(CryptoError, "OpenSSL call failed"); \ + return NULL; \ + } + +#define CHECK_RESULT_CTOR(expr) \ + if (!(expr)) { \ + ERR_clear_error(); \ + PyErr_SetString(CryptoError, "OpenSSL call failed"); \ + return -1; \ + } + +static PyObject *CryptoError; + +/* AEAD */ + +typedef struct { + PyObject_HEAD + EVP_CIPHER_CTX *decrypt_ctx; + EVP_CIPHER_CTX *encrypt_ctx; + unsigned char buffer[PACKET_LENGTH_MAX]; + unsigned char key[AEAD_KEY_LENGTH_MAX]; + unsigned char iv[AEAD_NONCE_LENGTH]; + unsigned char nonce[AEAD_NONCE_LENGTH]; +} AEADObject; + +static EVP_CIPHER_CTX * +create_ctx(const EVP_CIPHER *cipher, int key_length, int operation) +{ + EVP_CIPHER_CTX *ctx; + int res; + + ctx = EVP_CIPHER_CTX_new(); + CHECK_RESULT(ctx != 0); + + res = EVP_CipherInit_ex(ctx, cipher, NULL, NULL, NULL, operation); + CHECK_RESULT(res != 0); + + res = EVP_CIPHER_CTX_set_key_length(ctx, key_length); + CHECK_RESULT(res != 0); + + res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_CCM_SET_IVLEN, AEAD_NONCE_LENGTH, NULL); + CHECK_RESULT(res != 0); + + return ctx; +} + +static int +AEAD_init(AEADObject *self, PyObject *args, PyObject *kwargs) +{ + const char *cipher_name; + const unsigned char *key, *iv; + Py_ssize_t cipher_name_len, key_len, iv_len; + + if (!PyArg_ParseTuple(args, "y#y#y#", &cipher_name, 
&cipher_name_len, &key, &key_len, &iv, &iv_len)) + return -1; + + const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name); + if (evp_cipher == 0) { + PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name); + return -1; + } + if (key_len > AEAD_KEY_LENGTH_MAX) { + PyErr_SetString(CryptoError, "Invalid key length"); + return -1; + } + if (iv_len > AEAD_NONCE_LENGTH) { + PyErr_SetString(CryptoError, "Invalid iv length"); + return -1; + } + + memcpy(self->key, key, key_len); + memcpy(self->iv, iv, iv_len); + + self->decrypt_ctx = create_ctx(evp_cipher, key_len, 0); + CHECK_RESULT_CTOR(self->decrypt_ctx != 0); + + self->encrypt_ctx = create_ctx(evp_cipher, key_len, 1); + CHECK_RESULT_CTOR(self->encrypt_ctx != 0); + + return 0; +} + +static void +AEAD_dealloc(AEADObject *self) +{ + EVP_CIPHER_CTX_free(self->decrypt_ctx); + EVP_CIPHER_CTX_free(self->encrypt_ctx); +} + +static PyObject* +AEAD_decrypt(AEADObject *self, PyObject *args) +{ + const unsigned char *data, *associated; + Py_ssize_t data_len, associated_len; + int outlen, outlen2, res; + uint64_t pn; + + if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn)) + return NULL; + + if (data_len < AEAD_TAG_LENGTH || data_len > PACKET_LENGTH_MAX) { + PyErr_SetString(CryptoError, "Invalid payload length"); + return NULL; + } + + memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH); + for (int i = 0; i < 8; ++i) { + self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i); + } + + res = EVP_CIPHER_CTX_ctrl(self->decrypt_ctx, EVP_CTRL_CCM_SET_TAG, AEAD_TAG_LENGTH, (void*)(data + (data_len - AEAD_TAG_LENGTH))); + CHECK_RESULT(res != 0); + + res = EVP_CipherInit_ex(self->decrypt_ctx, NULL, NULL, self->key, self->nonce, 0); + CHECK_RESULT(res != 0); + + res = EVP_CipherUpdate(self->decrypt_ctx, NULL, &outlen, associated, associated_len); + CHECK_RESULT(res != 0); + + res = EVP_CipherUpdate(self->decrypt_ctx, self->buffer, &outlen, data, data_len - AEAD_TAG_LENGTH); + 
CHECK_RESULT(res != 0); + + res = EVP_CipherFinal_ex(self->decrypt_ctx, NULL, &outlen2); + if (res == 0) { + PyErr_SetString(CryptoError, "Payload decryption failed"); + return NULL; + } + + return PyBytes_FromStringAndSize((const char*)self->buffer, outlen); +} + +static PyObject* +AEAD_encrypt(AEADObject *self, PyObject *args) +{ + const unsigned char *data, *associated; + Py_ssize_t data_len, associated_len; + int outlen, outlen2, res; + uint64_t pn; + + if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn)) + return NULL; + + if (data_len > PACKET_LENGTH_MAX) { + PyErr_SetString(CryptoError, "Invalid payload length"); + return NULL; + } + + memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH); + for (int i = 0; i < 8; ++i) { + self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i); + } + + res = EVP_CipherInit_ex(self->encrypt_ctx, NULL, NULL, self->key, self->nonce, 1); + CHECK_RESULT(res != 0); + + res = EVP_CipherUpdate(self->encrypt_ctx, NULL, &outlen, associated, associated_len); + CHECK_RESULT(res != 0); + + res = EVP_CipherUpdate(self->encrypt_ctx, self->buffer, &outlen, data, data_len); + CHECK_RESULT(res != 0); + + res = EVP_CipherFinal_ex(self->encrypt_ctx, NULL, &outlen2); + CHECK_RESULT(res != 0 && outlen2 == 0); + + res = EVP_CIPHER_CTX_ctrl(self->encrypt_ctx, EVP_CTRL_CCM_GET_TAG, AEAD_TAG_LENGTH, self->buffer + outlen); + CHECK_RESULT(res != 0); + + return PyBytes_FromStringAndSize((const char*)self->buffer, outlen + AEAD_TAG_LENGTH); +} + +static PyMethodDef AEAD_methods[] = { + {"decrypt", (PyCFunction)AEAD_decrypt, METH_VARARGS, ""}, + {"encrypt", (PyCFunction)AEAD_encrypt, METH_VARARGS, ""}, + + {NULL} +}; + +static PyTypeObject AEADType = { + PyVarObject_HEAD_INIT(NULL, 0) + MODULE_NAME ".AEAD", /* tp_name */ + sizeof(AEADObject), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)AEAD_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved 
*/ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "AEAD objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + AEAD_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)AEAD_init, /* tp_init */ + 0, /* tp_alloc */ +}; + +/* HeaderProtection */ + +typedef struct { + PyObject_HEAD + EVP_CIPHER_CTX *ctx; + int is_chacha20; + unsigned char buffer[PACKET_LENGTH_MAX]; + unsigned char mask[31]; + unsigned char zero[5]; +} HeaderProtectionObject; + +static int +HeaderProtection_init(HeaderProtectionObject *self, PyObject *args, PyObject *kwargs) +{ + const char *cipher_name; + const unsigned char *key; + Py_ssize_t cipher_name_len, key_len; + int res; + + if (!PyArg_ParseTuple(args, "y#y#", &cipher_name, &cipher_name_len, &key, &key_len)) + return -1; + + const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name); + if (evp_cipher == 0) { + PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name); + return -1; + } + + memset(self->mask, 0, sizeof(self->mask)); + memset(self->zero, 0, sizeof(self->zero)); + self->is_chacha20 = cipher_name_len == 8 && memcmp(cipher_name, "chacha20", 8) == 0; + + self->ctx = EVP_CIPHER_CTX_new(); + CHECK_RESULT_CTOR(self->ctx != 0); + + res = EVP_CipherInit_ex(self->ctx, evp_cipher, NULL, NULL, NULL, 1); + CHECK_RESULT_CTOR(res != 0); + + res = EVP_CIPHER_CTX_set_key_length(self->ctx, key_len); + CHECK_RESULT_CTOR(res != 0); + + res = EVP_CipherInit_ex(self->ctx, NULL, NULL, key, NULL, 1); + CHECK_RESULT_CTOR(res != 0); + + return 0; +} + +static void 
+HeaderProtection_dealloc(HeaderProtectionObject *self) +{ + EVP_CIPHER_CTX_free(self->ctx); +} + +static int HeaderProtection_mask(HeaderProtectionObject *self, const unsigned char* sample) +{ + int outlen; + if (self->is_chacha20) { + return EVP_CipherInit_ex(self->ctx, NULL, NULL, NULL, sample, 1) && + EVP_CipherUpdate(self->ctx, self->mask, &outlen, self->zero, sizeof(self->zero)); + } else { + return EVP_CipherUpdate(self->ctx, self->mask, &outlen, sample, SAMPLE_LENGTH); + } +} + +static PyObject* +HeaderProtection_apply(HeaderProtectionObject *self, PyObject *args) +{ + const unsigned char *header, *payload; + Py_ssize_t header_len, payload_len; + int res; + + if (!PyArg_ParseTuple(args, "y#y#", &header, &header_len, &payload, &payload_len)) + return NULL; + + int pn_length = (header[0] & 0x03) + 1; + int pn_offset = header_len - pn_length; + + res = HeaderProtection_mask(self, payload + PACKET_NUMBER_LENGTH_MAX - pn_length); + CHECK_RESULT(res != 0); + + memcpy(self->buffer, header, header_len); + memcpy(self->buffer + header_len, payload, payload_len); + + if (self->buffer[0] & 0x80) { + self->buffer[0] ^= self->mask[0] & 0x0F; + } else { + self->buffer[0] ^= self->mask[0] & 0x1F; + } + + for (int i = 0; i < pn_length; ++i) { + self->buffer[pn_offset + i] ^= self->mask[1 + i]; + } + + return PyBytes_FromStringAndSize((const char*)self->buffer, header_len + payload_len); +} + +static PyObject* +HeaderProtection_remove(HeaderProtectionObject *self, PyObject *args) +{ + const unsigned char *packet; + Py_ssize_t packet_len; + int pn_offset, res; + + if (!PyArg_ParseTuple(args, "y#I", &packet, &packet_len, &pn_offset)) + return NULL; + + res = HeaderProtection_mask(self, packet + pn_offset + PACKET_NUMBER_LENGTH_MAX); + CHECK_RESULT(res != 0); + + memcpy(self->buffer, packet, pn_offset + PACKET_NUMBER_LENGTH_MAX); + + if (self->buffer[0] & 0x80) { + self->buffer[0] ^= self->mask[0] & 0x0F; + } else { + self->buffer[0] ^= self->mask[0] & 0x1F; + } + + int 
pn_length = (self->buffer[0] & 0x03) + 1; + uint32_t pn_truncated = 0; + for (int i = 0; i < pn_length; ++i) { + self->buffer[pn_offset + i] ^= self->mask[1 + i]; + pn_truncated = self->buffer[pn_offset + i] | (pn_truncated << 8); + } + + return Py_BuildValue("y#i", self->buffer, pn_offset + pn_length, pn_truncated); +} + +static PyMethodDef HeaderProtection_methods[] = { + {"apply", (PyCFunction)HeaderProtection_apply, METH_VARARGS, ""}, + {"remove", (PyCFunction)HeaderProtection_remove, METH_VARARGS, ""}, + {NULL} +}; + +static PyTypeObject HeaderProtectionType = { + PyVarObject_HEAD_INIT(NULL, 0) + MODULE_NAME ".HeaderProtection", /* tp_name */ + sizeof(HeaderProtectionObject), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)HeaderProtection_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "HeaderProtection objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + HeaderProtection_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)HeaderProtection_init, /* tp_init */ + 0, /* tp_alloc */ +}; + + +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + MODULE_NAME, /* m_name */ + "A faster buffer.", /* m_doc */ + -1, /* m_size */ + NULL, /* m_methods */ + NULL, /* m_reload */ + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL, /* m_free */ +}; + +PyMODINIT_FUNC +PyInit__crypto(void) +{ + PyObject* m; + + m = PyModule_Create(&moduledef); + if (m == NULL) + return NULL; + + 
CryptoError = PyErr_NewException(MODULE_NAME ".CryptoError", PyExc_ValueError, NULL); + Py_INCREF(CryptoError); + PyModule_AddObject(m, "CryptoError", CryptoError); + + AEADType.tp_new = PyType_GenericNew; + if (PyType_Ready(&AEADType) < 0) + return NULL; + Py_INCREF(&AEADType); + PyModule_AddObject(m, "AEAD", (PyObject *)&AEADType); + + HeaderProtectionType.tp_new = PyType_GenericNew; + if (PyType_Ready(&HeaderProtectionType) < 0) + return NULL; + Py_INCREF(&HeaderProtectionType); + PyModule_AddObject(m, "HeaderProtection", (PyObject *)&HeaderProtectionType); + + // ensure required ciphers are initialised + EVP_add_cipher(EVP_aes_128_ecb()); + EVP_add_cipher(EVP_aes_128_gcm()); + EVP_add_cipher(EVP_aes_256_ecb()); + EVP_add_cipher(EVP_aes_256_gcm()); + + return m; +} diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_crypto.pyi b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_crypto.pyi new file mode 100644 index 000000000000..32c5230d992e --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/_crypto.pyi @@ -0,0 +1,17 @@ +from typing import Tuple + +class AEAD: + def __init__(self, cipher_name: bytes, key: bytes, iv: bytes): ... + def decrypt( + self, data: bytes, associated_data: bytes, packet_number: int + ) -> bytes: ... + def encrypt( + self, data: bytes, associated_data: bytes, packet_number: int + ) -> bytes: ... + +class CryptoError(ValueError): ... + +class HeaderProtection: + def __init__(self, cipher_name: bytes, key: bytes): ... + def apply(self, plain_header: bytes, protected_payload: bytes) -> bytes: ... + def remove(self, packet: bytes, encrypted_offset: int) -> Tuple[bytes, int]: ... 
diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/about.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/about.py new file mode 100644 index 000000000000..82cd8f33d65d --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/about.py @@ -0,0 +1,7 @@ +__author__ = "Jeremy Lainé" +__email__ = "jeremy.laine@m4x.org" +__license__ = "BSD" +__summary__ = "An implementation of QUIC and HTTP/3" +__title__ = "aioquic" +__uri__ = "https://github.com/aiortc/aioquic" +__version__ = "0.8.7" diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/__init__.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/__init__.py new file mode 100644 index 000000000000..3fda69d89934 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/__init__.py @@ -0,0 +1,3 @@ +from .client import connect # noqa +from .protocol import QuicConnectionProtocol # noqa +from .server import serve # noqa diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/client.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/client.py new file mode 100644 index 000000000000..42937e9e892b --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/client.py @@ -0,0 +1,94 @@ +import asyncio +import ipaddress +import socket +import sys +from typing import AsyncGenerator, Callable, Optional, cast + +from ..quic.configuration import QuicConfiguration +from ..quic.connection import QuicConnection +from ..tls import SessionTicketHandler +from .compat import asynccontextmanager +from .protocol import QuicConnectionProtocol, QuicStreamHandler + +__all__ = ["connect"] + + +@asynccontextmanager +async def connect( + host: str, + port: int, + *, + configuration: Optional[QuicConfiguration] = None, + create_protocol: Optional[Callable] = QuicConnectionProtocol, + 
session_ticket_handler: Optional[SessionTicketHandler] = None, + stream_handler: Optional[QuicStreamHandler] = None, + wait_connected: bool = True, + local_port: int = 0, +) -> AsyncGenerator[QuicConnectionProtocol, None]: + """ + Connect to a QUIC server at the given `host` and `port`. + + :meth:`connect()` returns an awaitable. Awaiting it yields a + :class:`~aioquic.asyncio.QuicConnectionProtocol` which can be used to + create streams. + + :func:`connect` also accepts the following optional arguments: + + * ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration` + configuration object. + * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that + manages the connection. It should be a callable or class accepting the same + arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning + an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass. + * ``session_ticket_handler`` is a callback which is invoked by the TLS + engine when a new session ticket is received. + * ``stream_handler`` is a callback which is invoked whenever a stream is + created. It must accept two arguments: a :class:`asyncio.StreamReader` + and a :class:`asyncio.StreamWriter`. + * ``local_port`` is the UDP port number that this client wants to bind. 
+ """ + loop = asyncio.get_event_loop() + local_host = "::" + + # if host is not an IP address, pass it to enable SNI + try: + ipaddress.ip_address(host) + server_name = None + except ValueError: + server_name = host + + # lookup remote address + infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM) + addr = infos[0][4] + if len(addr) == 2: + # determine behaviour for IPv4 + if sys.platform == "win32": + # on Windows, we must use an IPv4 socket to reach an IPv4 host + local_host = "0.0.0.0" + else: + # other platforms support dual-stack sockets + addr = ("::ffff:" + addr[0], addr[1], 0, 0) + + # prepare QUIC connection + if configuration is None: + configuration = QuicConfiguration(is_client=True) + if server_name is not None: + configuration.server_name = server_name + connection = QuicConnection( + configuration=configuration, session_ticket_handler=session_ticket_handler + ) + + # connect + _, protocol = await loop.create_datagram_endpoint( + lambda: create_protocol(connection, stream_handler=stream_handler), + local_addr=(local_host, local_port), + ) + protocol = cast(QuicConnectionProtocol, protocol) + protocol.connect(addr) + if wait_connected: + await protocol.wait_connected() + try: + yield protocol + finally: + protocol.close() + await protocol.wait_closed() diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/compat.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/compat.py new file mode 100644 index 000000000000..201e62f6d942 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/compat.py @@ -0,0 +1,33 @@ +from contextlib import ContextDecorator +from functools import wraps + +try: + from contextlib import asynccontextmanager +except ImportError: + asynccontextmanager = None + + +class _AsyncGeneratorContextManager(ContextDecorator): + def __init__(self, func, args, kwds): + self.gen = func(*args, **kwds) + self.func, self.args, self.kwds = 
func, args, kwds + self.__doc__ = func.__doc__ + + async def __aenter__(self): + return await self.gen.__anext__() + + async def __aexit__(self, typ, value, traceback): + if typ is not None: + await self.gen.athrow(typ, value, traceback) + + +def _asynccontextmanager(func): + @wraps(func) + def helper(*args, **kwds): + return _AsyncGeneratorContextManager(func, args, kwds) + + return helper + + +if asynccontextmanager is None: + asynccontextmanager = _asynccontextmanager diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/protocol.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/protocol.py new file mode 100644 index 000000000000..be004335b8ff --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/protocol.py @@ -0,0 +1,240 @@ +import asyncio +from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast + +from ..quic import events +from ..quic.connection import NetworkAddress, QuicConnection + +QuicConnectionIdHandler = Callable[[bytes], None] +QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None] + + +class QuicConnectionProtocol(asyncio.DatagramProtocol): + def __init__( + self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None + ): + loop = asyncio.get_event_loop() + + self._closed = asyncio.Event() + self._connected = False + self._connected_waiter: Optional[asyncio.Future[None]] = None + self._loop = loop + self._ping_waiters: Dict[int, asyncio.Future[None]] = {} + self._quic = quic + self._stream_readers: Dict[int, asyncio.StreamReader] = {} + self._timer: Optional[asyncio.TimerHandle] = None + self._timer_at: Optional[float] = None + self._transmit_task: Optional[asyncio.Handle] = None + self._transport: Optional[asyncio.DatagramTransport] = None + + # callbacks + self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None + self._connection_id_retired_handler: 
QuicConnectionIdHandler = lambda c: None + self._connection_terminated_handler: Callable[[], None] = lambda: None + if stream_handler is not None: + self._stream_handler = stream_handler + else: + self._stream_handler = lambda r, w: None + + def change_connection_id(self) -> None: + """ + Change the connection ID used to communicate with the peer. + + The previous connection ID will be retired. + """ + self._quic.change_connection_id() + self.transmit() + + def close(self) -> None: + """ + Close the connection. + """ + self._quic.close() + self.transmit() + + def connect(self, addr: NetworkAddress) -> None: + """ + Initiate the TLS handshake. + + This method can only be called for clients and a single time. + """ + self._quic.connect(addr, now=self._loop.time()) + self.transmit() + + async def create_stream( + self, is_unidirectional: bool = False + ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + """ + Create a QUIC stream and return a pair of (reader, writer) objects. + + The returned reader and writer objects are instances of :class:`asyncio.StreamReader` + and :class:`asyncio.StreamWriter` classes. + """ + stream_id = self._quic.get_next_available_stream_id( + is_unidirectional=is_unidirectional + ) + return self._create_stream(stream_id) + + def request_key_update(self) -> None: + """ + Request an update of the encryption keys. + """ + self._quic.request_key_update() + self.transmit() + + async def ping(self) -> None: + """ + Ping the peer and wait for the response. + """ + waiter = self._loop.create_future() + uid = id(waiter) + self._ping_waiters[uid] = waiter + self._quic.send_ping(uid) + self.transmit() + await asyncio.shield(waiter) + + def transmit(self) -> None: + """ + Send pending datagrams to the peer and arm the timer if needed. 
+ """ + self._transmit_task = None + + # send datagrams + for data, addr in self._quic.datagrams_to_send(now=self._loop.time()): + self._transport.sendto(data, addr) + + # re-arm timer + timer_at = self._quic.get_timer() + if self._timer is not None and self._timer_at != timer_at: + self._timer.cancel() + self._timer = None + if self._timer is None and timer_at is not None: + self._timer = self._loop.call_at(timer_at, self._handle_timer) + self._timer_at = timer_at + + async def wait_closed(self) -> None: + """ + Wait for the connection to be closed. + """ + await self._closed.wait() + + async def wait_connected(self) -> None: + """ + Wait for the TLS handshake to complete. + """ + assert self._connected_waiter is None, "already awaiting connected" + if not self._connected: + self._connected_waiter = self._loop.create_future() + await asyncio.shield(self._connected_waiter) + + # asyncio.Transport + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self._transport = cast(asyncio.DatagramTransport, transport) + + def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None: + self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time()) + self._process_events() + self.transmit() + + # overridable + + def quic_event_received(self, event: events.QuicEvent) -> None: + """ + Called when a QUIC event is received. + + Reimplement this in your subclass to handle the events. 
+ """ + # FIXME: move this to a subclass + if isinstance(event, events.ConnectionTerminated): + for reader in self._stream_readers.values(): + reader.feed_eof() + elif isinstance(event, events.StreamDataReceived): + reader = self._stream_readers.get(event.stream_id, None) + if reader is None: + reader, writer = self._create_stream(event.stream_id) + self._stream_handler(reader, writer) + reader.feed_data(event.data) + if event.end_stream: + reader.feed_eof() + + # private + + def _create_stream( + self, stream_id: int + ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + adapter = QuicStreamAdapter(self, stream_id) + reader = asyncio.StreamReader() + writer = asyncio.StreamWriter(adapter, None, reader, self._loop) + self._stream_readers[stream_id] = reader + return reader, writer + + def _handle_timer(self) -> None: + now = max(self._timer_at, self._loop.time()) + self._timer = None + self._timer_at = None + self._quic.handle_timer(now=now) + self._process_events() + self.transmit() + + def _process_events(self) -> None: + event = self._quic.next_event() + while event is not None: + if isinstance(event, events.ConnectionIdIssued): + self._connection_id_issued_handler(event.connection_id) + elif isinstance(event, events.ConnectionIdRetired): + self._connection_id_retired_handler(event.connection_id) + elif isinstance(event, events.ConnectionTerminated): + self._connection_terminated_handler() + + # abort connection waiter + if self._connected_waiter is not None: + waiter = self._connected_waiter + self._connected_waiter = None + waiter.set_exception(ConnectionError) + + # abort ping waiters + for waiter in self._ping_waiters.values(): + waiter.set_exception(ConnectionError) + self._ping_waiters.clear() + + self._closed.set() + elif isinstance(event, events.HandshakeCompleted): + if self._connected_waiter is not None: + waiter = self._connected_waiter + self._connected = True + self._connected_waiter = None + waiter.set_result(None) + elif isinstance(event, 
events.PingAcknowledged): + waiter = self._ping_waiters.pop(event.uid, None) + if waiter is not None: + waiter.set_result(None) + self.quic_event_received(event) + event = self._quic.next_event() + + def _transmit_soon(self) -> None: + if self._transmit_task is None: + self._transmit_task = self._loop.call_soon(self.transmit) + + +class QuicStreamAdapter(asyncio.Transport): + def __init__(self, protocol: QuicConnectionProtocol, stream_id: int): + self.protocol = protocol + self.stream_id = stream_id + + def can_write_eof(self) -> bool: + return True + + def get_extra_info(self, name: str, default: Any = None) -> Any: + """ + Get information about the underlying QUIC stream. + """ + if name == "stream_id": + return self.stream_id + + def write(self, data): + self.protocol._quic.send_stream_data(self.stream_id, data) + self.protocol._transmit_soon() + + def write_eof(self): + self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True) + self.protocol._transmit_soon() diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/server.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/server.py new file mode 100644 index 000000000000..240ee460b6ed --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/asyncio/server.py @@ -0,0 +1,210 @@ +import asyncio +import os +from functools import partial +from typing import Callable, Dict, Optional, Text, Union, cast + +from ..buffer import Buffer +from ..quic.configuration import QuicConfiguration +from ..quic.connection import NetworkAddress, QuicConnection +from ..quic.packet import ( + PACKET_TYPE_INITIAL, + encode_quic_retry, + encode_quic_version_negotiation, + pull_quic_header, +) +from ..quic.retry import QuicRetryTokenHandler +from ..tls import SessionTicketFetcher, SessionTicketHandler +from .protocol import QuicConnectionProtocol, QuicStreamHandler + +__all__ = ["serve"] + + +class QuicServer(asyncio.DatagramProtocol): + 
def __init__( + self, + *, + configuration: QuicConfiguration, + create_protocol: Callable = QuicConnectionProtocol, + session_ticket_fetcher: Optional[SessionTicketFetcher] = None, + session_ticket_handler: Optional[SessionTicketHandler] = None, + stateless_retry: bool = False, + stream_handler: Optional[QuicStreamHandler] = None, + ) -> None: + self._configuration = configuration + self._create_protocol = create_protocol + self._loop = asyncio.get_event_loop() + self._protocols: Dict[bytes, QuicConnectionProtocol] = {} + self._session_ticket_fetcher = session_ticket_fetcher + self._session_ticket_handler = session_ticket_handler + self._transport: Optional[asyncio.DatagramTransport] = None + + self._stream_handler = stream_handler + + if stateless_retry: + self._retry = QuicRetryTokenHandler() + else: + self._retry = None + + def close(self): + for protocol in set(self._protocols.values()): + protocol.close() + self._protocols.clear() + self._transport.close() + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self._transport = cast(asyncio.DatagramTransport, transport) + + def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None: + data = cast(bytes, data) + buf = Buffer(data=data) + + try: + header = pull_quic_header( + buf, host_cid_length=self._configuration.connection_id_length + ) + except ValueError: + return + + # version negotiation + if ( + header.version is not None + and header.version not in self._configuration.supported_versions + ): + self._transport.sendto( + encode_quic_version_negotiation( + source_cid=header.destination_cid, + destination_cid=header.source_cid, + supported_versions=self._configuration.supported_versions, + ), + addr, + ) + return + + protocol = self._protocols.get(header.destination_cid, None) + original_connection_id: Optional[bytes] = None + if ( + protocol is None + and len(data) >= 1200 + and header.packet_type == PACKET_TYPE_INITIAL + ): + # stateless retry + if 
self._retry is not None: + if not header.token: + # create a retry token + self._transport.sendto( + encode_quic_retry( + version=header.version, + source_cid=os.urandom(8), + destination_cid=header.source_cid, + original_destination_cid=header.destination_cid, + retry_token=self._retry.create_token( + addr, header.destination_cid + ), + ), + addr, + ) + return + else: + # validate retry token + try: + original_connection_id = self._retry.validate_token( + addr, header.token + ) + except ValueError: + return + + # create new connection + connection = QuicConnection( + configuration=self._configuration, + logger_connection_id=original_connection_id or header.destination_cid, + original_connection_id=original_connection_id, + session_ticket_fetcher=self._session_ticket_fetcher, + session_ticket_handler=self._session_ticket_handler, + ) + protocol = self._create_protocol( + connection, stream_handler=self._stream_handler + ) + protocol.connection_made(self._transport) + + # register callbacks + protocol._connection_id_issued_handler = partial( + self._connection_id_issued, protocol=protocol + ) + protocol._connection_id_retired_handler = partial( + self._connection_id_retired, protocol=protocol + ) + protocol._connection_terminated_handler = partial( + self._connection_terminated, protocol=protocol + ) + + self._protocols[header.destination_cid] = protocol + self._protocols[connection.host_cid] = protocol + + if protocol is not None: + protocol.datagram_received(data, addr) + + def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol): + self._protocols[cid] = protocol + + def _connection_id_retired( + self, cid: bytes, protocol: QuicConnectionProtocol + ) -> None: + assert self._protocols[cid] == protocol + del self._protocols[cid] + + def _connection_terminated(self, protocol: QuicConnectionProtocol): + for cid, proto in list(self._protocols.items()): + if proto == protocol: + del self._protocols[cid] + + +async def serve( + host: str, + port: 
int, + *, + configuration: QuicConfiguration, + create_protocol: Callable = QuicConnectionProtocol, + session_ticket_fetcher: Optional[SessionTicketFetcher] = None, + session_ticket_handler: Optional[SessionTicketHandler] = None, + stateless_retry: bool = False, + stream_handler: QuicStreamHandler = None, +) -> QuicServer: + """ + Start a QUIC server at the given `host` and `port`. + + :func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration` + containing TLS certificate and private key as the ``configuration`` argument. + + :func:`serve` also accepts the following optional arguments: + + * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that + manages the connection. It should be a callable or class accepting the same + arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning + an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass. + * ``session_ticket_fetcher`` is a callback which is invoked by the TLS + engine when a session ticket is presented by the peer. It should return + the session ticket with the specified ID or `None` if it is not found. + * ``session_ticket_handler`` is a callback which is invoked by the TLS + engine when a new session ticket is issued. It should store the session + ticket for future lookup. + * ``stateless_retry`` specifies whether a stateless retry should be + performed prior to handling new connections. + * ``stream_handler`` is a callback which is invoked whenever a stream is + created. It must accept two arguments: a :class:`asyncio.StreamReader` + and a :class:`asyncio.StreamWriter`. 
+ """ + + loop = asyncio.get_event_loop() + + _, protocol = await loop.create_datagram_endpoint( + lambda: QuicServer( + configuration=configuration, + create_protocol=create_protocol, + session_ticket_fetcher=session_ticket_fetcher, + session_ticket_handler=session_ticket_handler, + stateless_retry=stateless_retry, + stream_handler=stream_handler, + ), + local_addr=(host, port), + ) + return cast(QuicServer, protocol) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/buffer.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/buffer.py new file mode 100644 index 000000000000..17a4298f954c --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/buffer.py @@ -0,0 +1,29 @@ +from ._buffer import Buffer, BufferReadError, BufferWriteError # noqa + +UINT_VAR_MAX = 0x3FFFFFFFFFFFFFFF + + +def encode_uint_var(value: int) -> bytes: + """ + Encode a variable-length unsigned integer. + """ + buf = Buffer(capacity=8) + buf.push_uint_var(value) + return buf.data + + +def size_uint_var(value: int) -> int: + """ + Return the number of bytes required to encode the given value + as a QUIC variable-length unsigned integer. 
+ """ + if value <= 0x3F: + return 1 + elif value <= 0x3FFF: + return 2 + elif value <= 0x3FFFFFFF: + return 4 + elif value <= 0x3FFFFFFFFFFFFFFF: + return 8 + else: + raise ValueError("Integer is too big for a variable-length integer") diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h0/__init__.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h0/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h0/connection.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h0/connection.py new file mode 100644 index 000000000000..2c8ddcec6b39 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h0/connection.py @@ -0,0 +1,63 @@ +from typing import Dict, List + +from aioquic.h3.events import DataReceived, H3Event, Headers, HeadersReceived +from aioquic.quic.connection import QuicConnection +from aioquic.quic.events import QuicEvent, StreamDataReceived + +H0_ALPN = ["hq-27", "hq-26", "hq-25"] + + +class H0Connection: + """ + An HTTP/0.9 connection object. 
+ """ + + def __init__(self, quic: QuicConnection): + self._headers_received: Dict[int, bool] = {} + self._is_client = quic.configuration.is_client + self._quic = quic + + def handle_event(self, event: QuicEvent) -> List[H3Event]: + http_events: List[H3Event] = [] + + if isinstance(event, StreamDataReceived) and (event.stream_id % 4) == 0: + data = event.data + if not self._headers_received.get(event.stream_id, False): + if self._is_client: + http_events.append( + HeadersReceived( + headers=[], stream_ended=False, stream_id=event.stream_id + ) + ) + else: + method, path = data.rstrip().split(b" ", 1) + http_events.append( + HeadersReceived( + headers=[(b":method", method), (b":path", path)], + stream_ended=False, + stream_id=event.stream_id, + ) + ) + data = b"" + self._headers_received[event.stream_id] = True + + http_events.append( + DataReceived( + data=data, stream_ended=event.end_stream, stream_id=event.stream_id + ) + ) + + return http_events + + def send_data(self, stream_id: int, data: bytes, end_stream: bool) -> None: + self._quic.send_stream_data(stream_id, data, end_stream) + + def send_headers( + self, stream_id: int, headers: Headers, end_stream: bool = False + ) -> None: + if self._is_client: + headers_dict = dict(headers) + data = headers_dict[b":method"] + b" " + headers_dict[b":path"] + b"\r\n" + else: + data = b"" + self._quic.send_stream_data(stream_id, data, end_stream) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h3/__init__.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h3/connection.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h3/connection.py new file mode 100644 index 000000000000..ddf8c54283cd --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h3/connection.py @@ -0,0 +1,776 @@ 
+import logging +from enum import Enum, IntEnum +from typing import Dict, List, Optional, Set + +import pylsqpack + +from aioquic.buffer import Buffer, BufferReadError, encode_uint_var +from aioquic.h3.events import ( + DataReceived, + H3Event, + Headers, + HeadersReceived, + PushPromiseReceived, +) +from aioquic.h3.exceptions import NoAvailablePushIDError +from aioquic.quic.connection import QuicConnection, stream_is_unidirectional +from aioquic.quic.events import QuicEvent, StreamDataReceived +from aioquic.quic.logger import QuicLoggerTrace + +logger = logging.getLogger("http3") + +H3_ALPN = ["h3-27", "h3-26", "h3-25"] + + +class ErrorCode(IntEnum): + HTTP_NO_ERROR = 0x100 + HTTP_GENERAL_PROTOCOL_ERROR = 0x101 + HTTP_INTERNAL_ERROR = 0x102 + HTTP_STREAM_CREATION_ERROR = 0x103 + HTTP_CLOSED_CRITICAL_STREAM = 0x104 + HTTP_FRAME_UNEXPECTED = 0x105 + HTTP_FRAME_ERROR = 0x106 + HTTP_EXCESSIVE_LOAD = 0x107 + HTTP_ID_ERROR = 0x108 + HTTP_SETTINGS_ERROR = 0x109 + HTTP_MISSING_SETTINGS = 0x10A + HTTP_REQUEST_REJECTED = 0x10B + HTTP_REQUEST_CANCELLED = 0x10C + HTTP_REQUEST_INCOMPLETE = 0x10D + HTTP_EARLY_RESPONSE = 0x10E + HTTP_CONNECT_ERROR = 0x10F + HTTP_VERSION_FALLBACK = 0x110 + HTTP_QPACK_DECOMPRESSION_FAILED = 0x200 + HTTP_QPACK_ENCODER_STREAM_ERROR = 0x201 + HTTP_QPACK_DECODER_STREAM_ERROR = 0x202 + + +class FrameType(IntEnum): + DATA = 0x0 + HEADERS = 0x1 + PRIORITY = 0x2 + CANCEL_PUSH = 0x3 + SETTINGS = 0x4 + PUSH_PROMISE = 0x5 + GOAWAY = 0x7 + MAX_PUSH_ID = 0xD + DUPLICATE_PUSH = 0xE + + +class HeadersState(Enum): + INITIAL = 0 + AFTER_HEADERS = 1 + AFTER_TRAILERS = 2 + + +class Setting(IntEnum): + QPACK_MAX_TABLE_CAPACITY = 1 + SETTINGS_MAX_HEADER_LIST_SIZE = 6 + QPACK_BLOCKED_STREAMS = 7 + SETTINGS_NUM_PLACEHOLDERS = 9 + + +class StreamType(IntEnum): + CONTROL = 0 + PUSH = 1 + QPACK_ENCODER = 2 + QPACK_DECODER = 3 + + +class ProtocolError(Exception): + """ + Base class for protocol errors. 
+ + These errors are not exposed to the API user, they are handled + in :meth:`H3Connection.handle_event`. + """ + + error_code = ErrorCode.HTTP_GENERAL_PROTOCOL_ERROR + + def __init__(self, reason_phrase: str = ""): + self.reason_phrase = reason_phrase + + +class QpackDecompressionFailed(ProtocolError): + error_code = ErrorCode.HTTP_QPACK_DECOMPRESSION_FAILED + + +class QpackDecoderStreamError(ProtocolError): + error_code = ErrorCode.HTTP_QPACK_DECODER_STREAM_ERROR + + +class QpackEncoderStreamError(ProtocolError): + error_code = ErrorCode.HTTP_QPACK_ENCODER_STREAM_ERROR + + +class StreamCreationError(ProtocolError): + error_code = ErrorCode.HTTP_STREAM_CREATION_ERROR + + +class FrameUnexpected(ProtocolError): + error_code = ErrorCode.HTTP_FRAME_UNEXPECTED + + +def encode_frame(frame_type: int, frame_data: bytes) -> bytes: + frame_length = len(frame_data) + buf = Buffer(capacity=frame_length + 16) + buf.push_uint_var(frame_type) + buf.push_uint_var(frame_length) + buf.push_bytes(frame_data) + return buf.data + + +def encode_settings(settings: Dict[int, int]) -> bytes: + buf = Buffer(capacity=1024) + for setting, value in settings.items(): + buf.push_uint_var(setting) + buf.push_uint_var(value) + return buf.data + + +def parse_max_push_id(data: bytes) -> int: + buf = Buffer(data=data) + max_push_id = buf.pull_uint_var() + assert buf.eof() + return max_push_id + + +def parse_settings(data: bytes) -> Dict[int, int]: + buf = Buffer(data=data) + settings = [] + while not buf.eof(): + setting = buf.pull_uint_var() + value = buf.pull_uint_var() + settings.append((setting, value)) + return dict(settings) + + +def qlog_encode_data_frame(byte_length: int, stream_id: int) -> Dict: + return { + "byte_length": str(byte_length), + "frame": {"frame_type": "data"}, + "stream_id": str(stream_id), + } + + +def qlog_encode_headers(headers: Headers) -> List[Dict]: + return [ + {"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers + ] + + +def 
qlog_encode_headers_frame( + byte_length: int, headers: Headers, stream_id: int +) -> Dict: + return { + "byte_length": str(byte_length), + "frame": {"frame_type": "headers", "headers": qlog_encode_headers(headers)}, + "stream_id": str(stream_id), + } + + +def qlog_encode_push_promise_frame( + byte_length: int, headers: Headers, push_id: int, stream_id: int +) -> Dict: + return { + "byte_length": str(byte_length), + "frame": { + "frame_type": "push_promise", + "headers": qlog_encode_headers(headers), + "push_id": str(push_id), + }, + "stream_id": str(stream_id), + } + + +class H3Stream: + def __init__(self, stream_id: int) -> None: + self.blocked = False + self.blocked_frame_size: Optional[int] = None + self.buffer = b"" + self.ended = False + self.frame_size: Optional[int] = None + self.frame_type: Optional[int] = None + self.headers_recv_state: HeadersState = HeadersState.INITIAL + self.headers_send_state: HeadersState = HeadersState.INITIAL + self.push_id: Optional[int] = None + self.stream_id = stream_id + self.stream_type: Optional[int] = None + + +class H3Connection: + """ + A low-level HTTP/3 connection object. + + :param quic: A :class:`~aioquic.connection.QuicConnection` instance. 
+ """ + + def __init__(self, quic: QuicConnection): + self._max_table_capacity = 4096 + self._blocked_streams = 16 + + self._is_client = quic.configuration.is_client + self._is_done = False + self._quic = quic + self._quic_logger: Optional[QuicLoggerTrace] = quic._quic_logger + self._decoder = pylsqpack.Decoder( + self._max_table_capacity, self._blocked_streams + ) + self._decoder_bytes_received = 0 + self._decoder_bytes_sent = 0 + self._encoder = pylsqpack.Encoder() + self._encoder_bytes_received = 0 + self._encoder_bytes_sent = 0 + self._stream: Dict[int, H3Stream] = {} + + self._max_push_id: Optional[int] = 8 if self._is_client else None + self._next_push_id: int = 0 + + self._local_control_stream_id: Optional[int] = None + self._local_decoder_stream_id: Optional[int] = None + self._local_encoder_stream_id: Optional[int] = None + + self._peer_control_stream_id: Optional[int] = None + self._peer_decoder_stream_id: Optional[int] = None + self._peer_encoder_stream_id: Optional[int] = None + + self._init_connection() + + def handle_event(self, event: QuicEvent) -> List[H3Event]: + """ + Handle a QUIC event and return a list of HTTP events. + + :param event: The QUIC event to handle. + """ + if isinstance(event, StreamDataReceived) and not self._is_done: + stream_id = event.stream_id + stream = self._get_or_create_stream(stream_id) + try: + if stream_id % 4 == 0: + return self._receive_request_or_push_data( + stream, event.data, event.end_stream + ) + elif stream_is_unidirectional(stream_id): + return self._receive_stream_data_uni( + stream, event.data, event.end_stream + ) + except ProtocolError as exc: + self._is_done = True + self._quic.close( + error_code=exc.error_code, reason_phrase=exc.reason_phrase + ) + return [] + + def send_push_promise(self, stream_id: int, headers: Headers) -> int: + """ + Send a push promise related to the specified stream. + + Returns the stream ID on which headers and data can be sent. 
+ + :param stream_id: The stream ID on which to send the data. + :param headers: The HTTP request headers for this push. + """ + assert not self._is_client, "Only servers may send a push promise." + if self._max_push_id is None or self._next_push_id >= self._max_push_id: + raise NoAvailablePushIDError + + # send push promise + push_id = self._next_push_id + self._next_push_id += 1 + self._quic.send_stream_data( + stream_id, + encode_frame( + FrameType.PUSH_PROMISE, + encode_uint_var(push_id) + self._encode_headers(stream_id, headers), + ), + ) + + #  create push stream + push_stream_id = self._create_uni_stream(StreamType.PUSH) + self._quic.send_stream_data(push_stream_id, encode_uint_var(push_id)) + + return push_stream_id + + def send_data(self, stream_id: int, data: bytes, end_stream: bool) -> None: + """ + Send data on the given stream. + + To retrieve datagram which need to be sent over the network call the QUIC + connection's :meth:`~aioquic.connection.QuicConnection.datagrams_to_send` + method. + + :param stream_id: The stream ID on which to send the data. + :param data: The data to send. + :param end_stream: Whether to end the stream. + """ + # check DATA frame is allowed + stream = self._get_or_create_stream(stream_id) + if stream.headers_send_state != HeadersState.AFTER_HEADERS: + raise FrameUnexpected("DATA frame is not allowed in this state") + + # log frame + if self._quic_logger is not None: + self._quic_logger.log_event( + category="http", + event="frame_created", + data=qlog_encode_data_frame(byte_length=len(data), stream_id=stream_id), + ) + + self._quic.send_stream_data( + stream_id, encode_frame(FrameType.DATA, data), end_stream + ) + + def send_headers( + self, stream_id: int, headers: Headers, end_stream: bool = False + ) -> None: + """ + Send headers on the given stream. + + To retrieve datagram which need to be sent over the network call the QUIC + connection's :meth:`~aioquic.connection.QuicConnection.datagrams_to_send` + method. 
+ + :param stream_id: The stream ID on which to send the headers. + :param headers: The HTTP headers to send. + :param end_stream: Whether to end the stream. + """ + # check HEADERS frame is allowed + stream = self._get_or_create_stream(stream_id) + if stream.headers_send_state == HeadersState.AFTER_TRAILERS: + raise FrameUnexpected("HEADERS frame is not allowed in this state") + + frame_data = self._encode_headers(stream_id, headers) + + # log frame + if self._quic_logger is not None: + self._quic_logger.log_event( + category="http", + event="frame_created", + data=qlog_encode_headers_frame( + byte_length=len(frame_data), headers=headers, stream_id=stream_id + ), + ) + + # update state and send headers + if stream.headers_send_state == HeadersState.INITIAL: + stream.headers_send_state = HeadersState.AFTER_HEADERS + else: + stream.headers_send_state = HeadersState.AFTER_TRAILERS + self._quic.send_stream_data( + stream_id, encode_frame(FrameType.HEADERS, frame_data), end_stream + ) + + def _create_uni_stream(self, stream_type: int) -> int: + """ + Create an unidirectional stream of the given type. + """ + stream_id = self._quic.get_next_available_stream_id(is_unidirectional=True) + self._quic.send_stream_data(stream_id, encode_uint_var(stream_type)) + return stream_id + + def _decode_headers(self, stream_id: int, frame_data: Optional[bytes]) -> Headers: + """ + Decode a HEADERS block and send decoder updates on the decoder stream. + + This is called with frame_data=None when a stream becomes unblocked. 
+ """ + try: + if frame_data is None: + decoder, headers = self._decoder.resume_header(stream_id) + else: + decoder, headers = self._decoder.feed_header(stream_id, frame_data) + self._decoder_bytes_sent += len(decoder) + self._quic.send_stream_data(self._local_decoder_stream_id, decoder) + except pylsqpack.DecompressionFailed as exc: + raise QpackDecompressionFailed() from exc + + return headers + + def _encode_headers(self, stream_id: int, headers: Headers) -> bytes: + """ + Encode a HEADERS block and send encoder updates on the encoder stream. + """ + encoder, frame_data = self._encoder.encode(stream_id, headers) + self._encoder_bytes_sent += len(encoder) + self._quic.send_stream_data(self._local_encoder_stream_id, encoder) + return frame_data + + def _get_or_create_stream(self, stream_id: int) -> H3Stream: + if stream_id not in self._stream: + self._stream[stream_id] = H3Stream(stream_id) + return self._stream[stream_id] + + def _handle_control_frame(self, frame_type: int, frame_data: bytes) -> None: + """ + Handle a frame received on the peer's control stream. 
+ """ + if frame_type == FrameType.SETTINGS: + settings = parse_settings(frame_data) + encoder = self._encoder.apply_settings( + max_table_capacity=settings.get(Setting.QPACK_MAX_TABLE_CAPACITY, 0), + blocked_streams=settings.get(Setting.QPACK_BLOCKED_STREAMS, 0), + ) + self._quic.send_stream_data(self._local_encoder_stream_id, encoder) + elif frame_type == FrameType.MAX_PUSH_ID: + if self._is_client: + raise FrameUnexpected("Servers must not send MAX_PUSH_ID") + self._max_push_id = parse_max_push_id(frame_data) + elif frame_type in ( + FrameType.DATA, + FrameType.HEADERS, + FrameType.PUSH_PROMISE, + FrameType.DUPLICATE_PUSH, + ): + raise FrameUnexpected("Invalid frame type on control stream") + + def _handle_request_or_push_frame( + self, + frame_type: int, + frame_data: Optional[bytes], + stream: H3Stream, + stream_ended: bool, + ) -> List[H3Event]: + """ + Handle a frame received on a request or push stream. + """ + http_events: List[H3Event] = [] + + if frame_type == FrameType.DATA: + # check DATA frame is allowed + if stream.headers_recv_state != HeadersState.AFTER_HEADERS: + raise FrameUnexpected("DATA frame is not allowed in this state") + + if stream_ended or frame_data: + http_events.append( + DataReceived( + data=frame_data, + push_id=stream.push_id, + stream_ended=stream_ended, + stream_id=stream.stream_id, + ) + ) + elif frame_type == FrameType.HEADERS: + # check HEADERS frame is allowed + if stream.headers_recv_state == HeadersState.AFTER_TRAILERS: + raise FrameUnexpected("HEADERS frame is not allowed in this state") + + # try to decode HEADERS, may raise pylsqpack.StreamBlocked + headers = self._decode_headers(stream.stream_id, frame_data) + + # log frame + if self._quic_logger is not None: + self._quic_logger.log_event( + category="http", + event="frame_parsed", + data=qlog_encode_headers_frame( + byte_length=stream.blocked_frame_size + if frame_data is None + else len(frame_data), + headers=headers, + stream_id=stream.stream_id, + ), + ) + + # 
update state and emit headers + if stream.headers_recv_state == HeadersState.INITIAL: + stream.headers_recv_state = HeadersState.AFTER_HEADERS + else: + stream.headers_recv_state = HeadersState.AFTER_TRAILERS + http_events.append( + HeadersReceived( + headers=headers, + push_id=stream.push_id, + stream_id=stream.stream_id, + stream_ended=stream_ended, + ) + ) + elif stream.frame_type == FrameType.PUSH_PROMISE and stream.push_id is None: + if not self._is_client: + raise FrameUnexpected("Clients must not send PUSH_PROMISE") + frame_buf = Buffer(data=frame_data) + push_id = frame_buf.pull_uint_var() + headers = self._decode_headers( + stream.stream_id, frame_data[frame_buf.tell() :] + ) + + # log frame + if self._quic_logger is not None: + self._quic_logger.log_event( + category="http", + event="frame_parsed", + data=qlog_encode_push_promise_frame( + byte_length=len(frame_data), + headers=headers, + push_id=push_id, + stream_id=stream.stream_id, + ), + ) + + # emit event + http_events.append( + PushPromiseReceived( + headers=headers, push_id=push_id, stream_id=stream.stream_id + ) + ) + elif frame_type in ( + FrameType.PRIORITY, + FrameType.CANCEL_PUSH, + FrameType.SETTINGS, + FrameType.PUSH_PROMISE, + FrameType.GOAWAY, + FrameType.MAX_PUSH_ID, + FrameType.DUPLICATE_PUSH, + ): + raise FrameUnexpected( + "Invalid frame type on request stream" + if stream.push_id is None + else "Invalid frame type on push stream" + ) + + return http_events + + def _init_connection(self) -> None: + # send our settings + self._local_control_stream_id = self._create_uni_stream(StreamType.CONTROL) + self._quic.send_stream_data( + self._local_control_stream_id, + encode_frame( + FrameType.SETTINGS, + encode_settings( + { + Setting.QPACK_MAX_TABLE_CAPACITY: self._max_table_capacity, + Setting.QPACK_BLOCKED_STREAMS: self._blocked_streams, + } + ), + ), + ) + if self._is_client and self._max_push_id is not None: + self._quic.send_stream_data( + self._local_control_stream_id, + 
encode_frame(FrameType.MAX_PUSH_ID, encode_uint_var(self._max_push_id)), + ) + + # create encoder and decoder streams + self._local_encoder_stream_id = self._create_uni_stream( + StreamType.QPACK_ENCODER + ) + self._local_decoder_stream_id = self._create_uni_stream( + StreamType.QPACK_DECODER + ) + + def _receive_request_or_push_data( + self, stream: H3Stream, data: bytes, stream_ended: bool + ) -> List[H3Event]: + """ + Handle data received on a request or push stream. + """ + http_events: List[H3Event] = [] + + stream.buffer += data + if stream_ended: + stream.ended = True + if stream.blocked: + return http_events + + # shortcut for DATA frame fragments + if ( + stream.frame_type == FrameType.DATA + and stream.frame_size is not None + and len(stream.buffer) < stream.frame_size + ): + http_events.append( + DataReceived( + data=stream.buffer, + push_id=stream.push_id, + stream_id=stream.stream_id, + stream_ended=False, + ) + ) + stream.frame_size -= len(stream.buffer) + stream.buffer = b"" + return http_events + + # handle lone FIN + if stream_ended and not stream.buffer: + http_events.append( + DataReceived( + data=b"", + push_id=stream.push_id, + stream_id=stream.stream_id, + stream_ended=True, + ) + ) + return http_events + + buf = Buffer(data=stream.buffer) + consumed = 0 + + while not buf.eof(): + # fetch next frame header + if stream.frame_size is None: + try: + stream.frame_type = buf.pull_uint_var() + stream.frame_size = buf.pull_uint_var() + except BufferReadError: + break + consumed = buf.tell() + + # log frame + if ( + self._quic_logger is not None + and stream.frame_type == FrameType.DATA + ): + self._quic_logger.log_event( + category="http", + event="frame_parsed", + data=qlog_encode_data_frame( + byte_length=stream.frame_size, stream_id=stream.stream_id + ), + ) + + # check how much data is available + chunk_size = min(stream.frame_size, buf.capacity - consumed) + if stream.frame_type != FrameType.DATA and chunk_size < stream.frame_size: + break + + # 
read available data + frame_data = buf.pull_bytes(chunk_size) + consumed = buf.tell() + + # detect end of frame + stream.frame_size -= chunk_size + if not stream.frame_size: + stream.frame_size = None + + try: + http_events.extend( + self._handle_request_or_push_frame( + frame_type=stream.frame_type, + frame_data=frame_data, + stream=stream, + stream_ended=stream.ended and buf.eof(), + ) + ) + except pylsqpack.StreamBlocked: + stream.blocked = True + stream.blocked_frame_size = len(frame_data) + break + + # remove processed data from buffer + stream.buffer = stream.buffer[consumed:] + + return http_events + + def _receive_stream_data_uni( + self, stream: H3Stream, data: bytes, stream_ended: bool + ) -> List[H3Event]: + http_events: List[H3Event] = [] + + stream.buffer += data + if stream_ended: + stream.ended = True + + buf = Buffer(data=stream.buffer) + consumed = 0 + unblocked_streams: Set[int] = set() + + while stream.stream_type == StreamType.PUSH or not buf.eof(): + # fetch stream type for unidirectional streams + if stream.stream_type is None: + try: + stream.stream_type = buf.pull_uint_var() + except BufferReadError: + break + consumed = buf.tell() + + # check unicity + if stream.stream_type == StreamType.CONTROL: + if self._peer_control_stream_id is not None: + raise StreamCreationError("Only one control stream is allowed") + self._peer_control_stream_id = stream.stream_id + elif stream.stream_type == StreamType.QPACK_DECODER: + if self._peer_decoder_stream_id is not None: + raise StreamCreationError( + "Only one QPACK decoder stream is allowed" + ) + self._peer_decoder_stream_id = stream.stream_id + elif stream.stream_type == StreamType.QPACK_ENCODER: + if self._peer_encoder_stream_id is not None: + raise StreamCreationError( + "Only one QPACK encoder stream is allowed" + ) + self._peer_encoder_stream_id = stream.stream_id + + if stream.stream_type == StreamType.CONTROL: + # fetch next frame + try: + frame_type = buf.pull_uint_var() + frame_length = 
buf.pull_uint_var() + frame_data = buf.pull_bytes(frame_length) + except BufferReadError: + break + consumed = buf.tell() + + self._handle_control_frame(frame_type, frame_data) + elif stream.stream_type == StreamType.PUSH: + # fetch push id + if stream.push_id is None: + try: + stream.push_id = buf.pull_uint_var() + except BufferReadError: + break + consumed = buf.tell() + + # remove processed data from buffer + stream.buffer = stream.buffer[consumed:] + + return self._receive_request_or_push_data(stream, b"", stream_ended) + elif stream.stream_type == StreamType.QPACK_DECODER: + # feed unframed data to decoder + data = buf.pull_bytes(buf.capacity - buf.tell()) + consumed = buf.tell() + try: + self._encoder.feed_decoder(data) + except pylsqpack.DecoderStreamError as exc: + raise QpackDecoderStreamError() from exc + self._decoder_bytes_received += len(data) + elif stream.stream_type == StreamType.QPACK_ENCODER: + # feed unframed data to encoder + data = buf.pull_bytes(buf.capacity - buf.tell()) + consumed = buf.tell() + try: + unblocked_streams.update(self._decoder.feed_encoder(data)) + except pylsqpack.EncoderStreamError as exc: + raise QpackEncoderStreamError() from exc + self._encoder_bytes_received += len(data) + else: + # unknown stream type, discard data + buf.seek(buf.capacity) + consumed = buf.tell() + + # remove processed data from buffer + stream.buffer = stream.buffer[consumed:] + + # process unblocked streams + for stream_id in unblocked_streams: + stream = self._stream[stream_id] + + # resume headers + http_events.extend( + self._handle_request_or_push_frame( + frame_type=FrameType.HEADERS, + frame_data=None, + stream=stream, + stream_ended=stream.ended and not stream.buffer, + ) + ) + stream.blocked = False + stream.blocked_frame_size = None + + # resume processing + if stream.buffer: + http_events.extend( + self._receive_request_or_push_data(stream, b"", stream.ended) + ) + + return http_events diff --git 
a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h3/events.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h3/events.py new file mode 100644 index 000000000000..cd21c5abc427 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/h3/events.py @@ -0,0 +1,66 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple + +Headers = List[Tuple[bytes, bytes]] + + +class H3Event: + """ + Base class for HTTP/3 events. + """ + + +@dataclass +class DataReceived(H3Event): + """ + The DataReceived event is fired whenever data is received on a stream from + the remote peer. + """ + + data: bytes + "The data which was received." + + stream_id: int + "The ID of the stream the data was received for." + + stream_ended: bool + "Whether the STREAM frame had the FIN bit set." + + push_id: Optional[int] = None + "The Push ID or `None` if this is not a push." + + +@dataclass +class HeadersReceived(H3Event): + """ + The HeadersReceived event is fired whenever headers are received. + """ + + headers: Headers + "The headers." + + stream_id: int + "The ID of the stream the headers were received for." + + stream_ended: bool + "Whether the STREAM frame had the FIN bit set." + + push_id: Optional[int] = None + "The Push ID or `None` if this is not a push." + + +@dataclass +class PushPromiseReceived(H3Event): + """ + The PushedStreamReceived event is fired whenever a pushed stream has been + received from the remote peer. + """ + + headers: Headers + "The request headers." + + push_id: int + "The Push ID of the push promise." + + stream_id: int + "The Stream ID of the stream that the push is related to." 
from dataclasses import dataclass, field
from os import PathLike
from typing import Any, List, Optional, TextIO, Union

from ..tls import SessionTicket, load_pem_private_key, load_pem_x509_certificates
from .logger import QuicLogger
from .packet import QuicProtocolVersion


@dataclass
class QuicConfiguration:
    """
    A QUIC configuration.
    """

    alpn_protocols: Optional[List[str]] = None
    """
    A list of supported ALPN protocols.
    """

    connection_id_length: int = 8
    """
    The length in bytes of local connection IDs.
    """

    idle_timeout: float = 60.0
    """
    The idle timeout in seconds.

    The connection is terminated if nothing is received for the given duration.
    """

    is_client: bool = True
    """
    Whether this is the client side of the QUIC connection.
    """

    max_data: int = 1048576
    """
    Connection-wide flow control limit.
    """

    max_stream_data: int = 1048576
    """
    Per-stream flow control limit.
    """

    quic_logger: Optional[QuicLogger] = None
    """
    The :class:`~aioquic.quic.logger.QuicLogger` instance to log events to.
    """

    # FIX: was annotated `TextIO = None`, an implicit-Optional annotation that
    # modern type checkers reject; the attribute genuinely defaults to None.
    secrets_log_file: Optional[TextIO] = None
    """
    A file-like object in which to log traffic secrets.

    This is useful to analyze traffic captures with Wireshark.
    """

    server_name: Optional[str] = None
    """
    The server name to send during the TLS handshake as Server Name Indication.

    .. note:: This is only used by clients.
    """

    session_ticket: Optional[SessionTicket] = None
    """
    The TLS session ticket which should be used for session resumption.
    """

    # Remaining fields are lower-level knobs, normally set via the helper
    # methods below rather than directly.
    cadata: Optional[bytes] = None
    cafile: Optional[str] = None
    capath: Optional[str] = None
    certificate: Any = None
    certificate_chain: List[Any] = field(default_factory=list)
    max_datagram_frame_size: Optional[int] = None
    private_key: Any = None
    quantum_readiness_test: bool = False
    supported_versions: List[int] = field(
        default_factory=lambda: [
            QuicProtocolVersion.DRAFT_27,
            QuicProtocolVersion.DRAFT_26,
            QuicProtocolVersion.DRAFT_25,
        ]
    )
    verify_mode: Optional[int] = None

    def load_cert_chain(
        self,
        certfile: PathLike,
        keyfile: Optional[PathLike] = None,
        password: Optional[Union[bytes, str]] = None,
    ) -> None:
        """
        Load a private key and the corresponding certificate.

        :param certfile: path to a PEM file; the first certificate becomes the
            leaf certificate, any following ones form the chain.
        :param keyfile: optional path to a PEM file holding the private key;
            when omitted, no private key is loaded.
        :param password: optional password protecting the key; a `str` is
            encoded as UTF-8 before use.
        """
        with open(certfile, "rb") as fp:
            certificates = load_pem_x509_certificates(fp.read())
        self.certificate = certificates[0]
        self.certificate_chain = certificates[1:]

        if keyfile is not None:
            with open(keyfile, "rb") as fp:
                self.private_key = load_pem_private_key(
                    fp.read(),
                    password=password.encode("utf8")
                    if isinstance(password, str)
                    else password,
                )

    def load_verify_locations(
        self,
        cafile: Optional[str] = None,
        capath: Optional[str] = None,
        cadata: Optional[bytes] = None,
    ) -> None:
        """
        Load a set of "certification authority" (CA) certificates used to
        validate other peers' certificates.

        The values are only stored here; they are consumed when the TLS
        context is set up elsewhere.
        """
        self.cafile = cafile
        self.capath = capath
        self.cadata = cadata
import events +from .configuration import QuicConfiguration +from .crypto import CryptoError, CryptoPair, KeyUnavailableError +from .logger import QuicLoggerTrace +from .packet import ( + NON_ACK_ELICITING_FRAME_TYPES, + PACKET_TYPE_HANDSHAKE, + PACKET_TYPE_INITIAL, + PACKET_TYPE_ONE_RTT, + PACKET_TYPE_RETRY, + PACKET_TYPE_ZERO_RTT, + PROBING_FRAME_TYPES, + RETRY_INTEGRITY_TAG_SIZE, + QuicErrorCode, + QuicFrameType, + QuicProtocolVersion, + QuicStreamFrame, + QuicTransportParameters, + get_retry_integrity_tag, + get_spin_bit, + is_long_header, + pull_ack_frame, + pull_quic_header, + pull_quic_transport_parameters, + push_ack_frame, + push_quic_transport_parameters, +) +from .packet_builder import ( + PACKET_MAX_SIZE, + QuicDeliveryState, + QuicPacketBuilder, + QuicPacketBuilderStop, +) +from .recovery import K_GRANULARITY, QuicPacketRecovery, QuicPacketSpace +from .stream import QuicStream + +logger = logging.getLogger("quic") + +EPOCH_SHORTCUTS = { + "I": tls.Epoch.INITIAL, + "H": tls.Epoch.HANDSHAKE, + "0": tls.Epoch.ZERO_RTT, + "1": tls.Epoch.ONE_RTT, +} +MAX_EARLY_DATA = 0xFFFFFFFF +SECRETS_LABELS = [ + [ + None, + "QUIC_CLIENT_EARLY_TRAFFIC_SECRET", + "QUIC_CLIENT_HANDSHAKE_TRAFFIC_SECRET", + "QUIC_CLIENT_TRAFFIC_SECRET_0", + ], + [ + None, + None, + "QUIC_SERVER_HANDSHAKE_TRAFFIC_SECRET", + "QUIC_SERVER_TRAFFIC_SECRET_0", + ], +] +STREAM_FLAGS = 0x07 + +NetworkAddress = Any + +# frame sizes +ACK_FRAME_CAPACITY = 64 # FIXME: this is arbitrary! 
APPLICATION_CLOSE_FRAME_CAPACITY = 1 + 8 + 8  # + reason length
HANDSHAKE_DONE_FRAME_CAPACITY = 1
MAX_DATA_FRAME_CAPACITY = 1 + 8
MAX_STREAM_DATA_FRAME_CAPACITY = 1 + 8 + 8
NEW_CONNECTION_ID_FRAME_CAPACITY = 1 + 8 + 8 + 1 + 20 + 16
PATH_CHALLENGE_FRAME_CAPACITY = 1 + 8
PATH_RESPONSE_FRAME_CAPACITY = 1 + 8
PING_FRAME_CAPACITY = 1
RETIRE_CONNECTION_ID_CAPACITY = 1 + 8
STREAMS_BLOCKED_CAPACITY = 1 + 8
TRANSPORT_CLOSE_FRAME_CAPACITY = 1 + 8 + 8 + 8  # + reason length


def EPOCHS(shortcut: str) -> FrozenSet[tls.Epoch]:
    """
    Expand a shortcut string such as "IH01" into the matching set of epochs.
    """
    return frozenset(EPOCH_SHORTCUTS[letter] for letter in shortcut)


def dump_cid(cid: bytes) -> str:
    """
    Render a connection ID as a lowercase hex string for logging.
    """
    return cid.hex()


def get_epoch(packet_type: int) -> tls.Epoch:
    """
    Map a long-header packet type to its TLS epoch; anything else is 1-RTT.
    """
    epoch_by_type = {
        PACKET_TYPE_INITIAL: tls.Epoch.INITIAL,
        PACKET_TYPE_ZERO_RTT: tls.Epoch.ZERO_RTT,
        PACKET_TYPE_HANDSHAKE: tls.Epoch.HANDSHAKE,
    }
    return epoch_by_type.get(packet_type, tls.Epoch.ONE_RTT)


def stream_is_client_initiated(stream_id: int) -> bool:
    """
    Returns True if the stream is client initiated.
    """
    # Even stream IDs are client-initiated.
    return stream_id % 2 == 0


def stream_is_unidirectional(stream_id: int) -> bool:
    """
    Returns True if the stream is unidirectional.
    """
    # Bit 1 of the stream ID marks unidirectional streams.
    return stream_id & 2 != 0
+ """ + return bool(stream_id & 2) + + +class QuicConnectionError(Exception): + def __init__(self, error_code: int, frame_type: int, reason_phrase: str): + self.error_code = error_code + self.frame_type = frame_type + self.reason_phrase = reason_phrase + + def __str__(self) -> str: + s = "Error: %d, reason: %s" % (self.error_code, self.reason_phrase) + if self.frame_type is not None: + s += ", frame_type: %s" % self.frame_type + return s + + +class QuicConnectionAdapter(logging.LoggerAdapter): + def process(self, msg: str, kwargs: Any) -> Tuple[str, Any]: + return "[%s] %s" % (self.extra["id"], msg), kwargs + + +@dataclass +class QuicConnectionId: + cid: bytes + sequence_number: int + stateless_reset_token: bytes = b"" + was_sent: bool = False + + +class QuicConnectionState(Enum): + FIRSTFLIGHT = 0 + CONNECTED = 1 + CLOSING = 2 + DRAINING = 3 + TERMINATED = 4 + + +@dataclass +class QuicNetworkPath: + addr: NetworkAddress + bytes_received: int = 0 + bytes_sent: int = 0 + is_validated: bool = False + local_challenge: Optional[bytes] = None + remote_challenge: Optional[bytes] = None + + def can_send(self, size: int) -> bool: + return self.is_validated or (self.bytes_sent + size) <= 3 * self.bytes_received + + +@dataclass +class QuicReceiveContext: + epoch: tls.Epoch + host_cid: bytes + network_path: QuicNetworkPath + quic_logger_frames: Optional[List[Any]] + time: float + + +END_STATES = frozenset( + [ + QuicConnectionState.CLOSING, + QuicConnectionState.DRAINING, + QuicConnectionState.TERMINATED, + ] +) + + +class QuicConnection: + """ + A QUIC connection. + + The state machine is driven by three kinds of sources: + + - the API user requesting data to be send out (see :meth:`connect`, + :meth:`send_ping`, :meth:`send_datagram_data` and :meth:`send_stream_data`) + - data being received from the network (see :meth:`receive_datagram`) + - a timer firing (see :meth:`handle_timer`) + + :param configuration: The QUIC configuration to use. 
+ """ + + def __init__( + self, + *, + configuration: QuicConfiguration, + logger_connection_id: Optional[bytes] = None, + original_connection_id: Optional[bytes] = None, + session_ticket_fetcher: Optional[tls.SessionTicketFetcher] = None, + session_ticket_handler: Optional[tls.SessionTicketHandler] = None, + ) -> None: + if configuration.is_client: + assert ( + original_connection_id is None + ), "Cannot set original_connection_id for a client" + else: + assert ( + configuration.certificate is not None + ), "SSL certificate is required for a server" + assert ( + configuration.private_key is not None + ), "SSL private key is required for a server" + + # configuration + self._configuration = configuration + self._is_client = configuration.is_client + + self._ack_delay = K_GRANULARITY + self._close_at: Optional[float] = None + self._close_event: Optional[events.ConnectionTerminated] = None + self._connect_called = False + self._cryptos: Dict[tls.Epoch, CryptoPair] = {} + self._crypto_buffers: Dict[tls.Epoch, Buffer] = {} + self._crypto_streams: Dict[tls.Epoch, QuicStream] = {} + self._events: Deque[events.QuicEvent] = deque() + self._handshake_complete = False + self._handshake_confirmed = False + self._host_cids = [ + QuicConnectionId( + cid=os.urandom(configuration.connection_id_length), + sequence_number=0, + stateless_reset_token=os.urandom(16), + was_sent=True, + ) + ] + self.host_cid = self._host_cids[0].cid + self._host_cid_seq = 1 + self._local_ack_delay_exponent = 3 + self._local_active_connection_id_limit = 8 + self._local_max_data = configuration.max_data + self._local_max_data_sent = configuration.max_data + self._local_max_data_used = 0 + self._local_max_stream_data_bidi_local = configuration.max_stream_data + self._local_max_stream_data_bidi_remote = configuration.max_stream_data + self._local_max_stream_data_uni = configuration.max_stream_data + self._local_max_streams_bidi = 128 + self._local_max_streams_uni = 128 + self._loss_at: Optional[float] = 
None + self._network_paths: List[QuicNetworkPath] = [] + self._original_connection_id = original_connection_id + self._pacing_at: Optional[float] = None + self._packet_number = 0 + self._parameters_received = False + self._peer_cid = os.urandom(configuration.connection_id_length) + self._peer_cid_seq: Optional[int] = None + self._peer_cid_available: List[QuicConnectionId] = [] + self._peer_token = b"" + self._quic_logger: Optional[QuicLoggerTrace] = None + self._remote_ack_delay_exponent = 3 + self._remote_active_connection_id_limit = 0 + self._remote_idle_timeout = 0.0 # seconds + self._remote_max_data = 0 + self._remote_max_data_used = 0 + self._remote_max_datagram_frame_size: Optional[int] = None + self._remote_max_stream_data_bidi_local = 0 + self._remote_max_stream_data_bidi_remote = 0 + self._remote_max_stream_data_uni = 0 + self._remote_max_streams_bidi = 0 + self._remote_max_streams_uni = 0 + self._spaces: Dict[tls.Epoch, QuicPacketSpace] = {} + self._spin_bit = False + self._spin_highest_pn = 0 + self._state = QuicConnectionState.FIRSTFLIGHT + self._stateless_retry_count = 0 + self._streams: Dict[int, QuicStream] = {} + self._streams_blocked_bidi: List[QuicStream] = [] + self._streams_blocked_uni: List[QuicStream] = [] + self._version: Optional[int] = None + + # logging + if logger_connection_id is None: + logger_connection_id = self._peer_cid + self._logger = QuicConnectionAdapter( + logger, {"id": dump_cid(logger_connection_id)} + ) + if configuration.quic_logger: + self._quic_logger = configuration.quic_logger.start_trace( + is_client=configuration.is_client, odcid=logger_connection_id + ) + + # loss recovery + self._loss = QuicPacketRecovery( + is_client_without_1rtt=self._is_client, + quic_logger=self._quic_logger, + send_probe=self._send_probe, + ) + + # things to send + self._close_pending = False + self._datagrams_pending: Deque[bytes] = deque() + self._handshake_done_pending = False + self._ping_pending: List[int] = [] + self._probe_pending = 
False + self._retire_connection_ids: List[int] = [] + self._streams_blocked_pending = False + + # callbacks + self._session_ticket_fetcher = session_ticket_fetcher + self._session_ticket_handler = session_ticket_handler + + # frame handlers + self.__frame_handlers = { + 0x00: (self._handle_padding_frame, EPOCHS("IH01")), + 0x01: (self._handle_ping_frame, EPOCHS("IH01")), + 0x02: (self._handle_ack_frame, EPOCHS("IH1")), + 0x03: (self._handle_ack_frame, EPOCHS("IH1")), + 0x04: (self._handle_reset_stream_frame, EPOCHS("01")), + 0x05: (self._handle_stop_sending_frame, EPOCHS("01")), + 0x06: (self._handle_crypto_frame, EPOCHS("IH1")), + 0x07: (self._handle_new_token_frame, EPOCHS("1")), + 0x08: (self._handle_stream_frame, EPOCHS("01")), + 0x09: (self._handle_stream_frame, EPOCHS("01")), + 0x0A: (self._handle_stream_frame, EPOCHS("01")), + 0x0B: (self._handle_stream_frame, EPOCHS("01")), + 0x0C: (self._handle_stream_frame, EPOCHS("01")), + 0x0D: (self._handle_stream_frame, EPOCHS("01")), + 0x0E: (self._handle_stream_frame, EPOCHS("01")), + 0x0F: (self._handle_stream_frame, EPOCHS("01")), + 0x10: (self._handle_max_data_frame, EPOCHS("01")), + 0x11: (self._handle_max_stream_data_frame, EPOCHS("01")), + 0x12: (self._handle_max_streams_bidi_frame, EPOCHS("01")), + 0x13: (self._handle_max_streams_uni_frame, EPOCHS("01")), + 0x14: (self._handle_data_blocked_frame, EPOCHS("01")), + 0x15: (self._handle_stream_data_blocked_frame, EPOCHS("01")), + 0x16: (self._handle_streams_blocked_frame, EPOCHS("01")), + 0x17: (self._handle_streams_blocked_frame, EPOCHS("01")), + 0x18: (self._handle_new_connection_id_frame, EPOCHS("01")), + 0x19: (self._handle_retire_connection_id_frame, EPOCHS("01")), + 0x1A: (self._handle_path_challenge_frame, EPOCHS("01")), + 0x1B: (self._handle_path_response_frame, EPOCHS("01")), + 0x1C: (self._handle_connection_close_frame, EPOCHS("IH1")), + 0x1D: (self._handle_connection_close_frame, EPOCHS("1")), + 0x1E: (self._handle_handshake_done_frame, EPOCHS("1")), + 
0x30: (self._handle_datagram_frame, EPOCHS("01")), + 0x31: (self._handle_datagram_frame, EPOCHS("01")), + } + + @property + def configuration(self) -> QuicConfiguration: + return self._configuration + + def change_connection_id(self) -> None: + """ + Switch to the next available connection ID and retire + the previous one. + + After calling this method call :meth:`datagrams_to_send` to retrieve data + which needs to be sent. + """ + if self._peer_cid_available: + # retire previous CID + self._logger.debug( + "Retiring CID %s (%d)", dump_cid(self._peer_cid), self._peer_cid_seq + ) + self._retire_connection_ids.append(self._peer_cid_seq) + + # assign new CID + connection_id = self._peer_cid_available.pop(0) + self._peer_cid_seq = connection_id.sequence_number + self._peer_cid = connection_id.cid + self._logger.debug( + "Switching to CID %s (%d)", dump_cid(self._peer_cid), self._peer_cid_seq + ) + + def close( + self, + error_code: int = QuicErrorCode.NO_ERROR, + frame_type: Optional[int] = None, + reason_phrase: str = "", + ) -> None: + """ + Close the connection. + + :param error_code: An error code indicating why the connection is + being closed. + :param reason_phrase: A human-readable explanation of why the + connection is being closed. + """ + if self._state not in END_STATES: + self._close_event = events.ConnectionTerminated( + error_code=error_code, + frame_type=frame_type, + reason_phrase=reason_phrase, + ) + self._close_pending = True + + def connect(self, addr: NetworkAddress, now: float) -> None: + """ + Initiate the TLS handshake. + + This method can only be called for clients and a single time. + + After calling this method call :meth:`datagrams_to_send` to retrieve data + which needs to be sent. + + :param addr: The network address of the remote peer. + :param now: The current time. 
+ """ + assert ( + self._is_client and not self._connect_called + ), "connect() can only be called for clients and a single time" + self._connect_called = True + + self._network_paths = [QuicNetworkPath(addr, is_validated=True)] + self._version = self._configuration.supported_versions[0] + self._connect(now=now) + + def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]: + """ + Return a list of `(data, addr)` tuples of datagrams which need to be + sent, and the network address to which they need to be sent. + + After calling this method call :meth:`get_timer` to know when the next + timer needs to be set. + + :param now: The current time. + """ + network_path = self._network_paths[0] + + if self._state in END_STATES: + return [] + + # build datagrams + builder = QuicPacketBuilder( + host_cid=self.host_cid, + is_client=self._is_client, + packet_number=self._packet_number, + peer_cid=self._peer_cid, + peer_token=self._peer_token, + quic_logger=self._quic_logger, + spin_bit=self._spin_bit, + version=self._version, + ) + if self._close_pending: + for epoch, packet_type in ( + (tls.Epoch.ONE_RTT, PACKET_TYPE_ONE_RTT), + (tls.Epoch.HANDSHAKE, PACKET_TYPE_HANDSHAKE), + (tls.Epoch.INITIAL, PACKET_TYPE_INITIAL), + ): + crypto = self._cryptos[epoch] + if crypto.send.is_valid(): + builder.start_packet(packet_type, crypto) + self._write_connection_close_frame( + builder=builder, + error_code=self._close_event.error_code, + frame_type=self._close_event.frame_type, + reason_phrase=self._close_event.reason_phrase, + ) + self._close_pending = False + break + self._close_begin(is_initiator=True, now=now) + else: + # congestion control + builder.max_flight_bytes = ( + self._loss.congestion_window - self._loss.bytes_in_flight + ) + if self._probe_pending and builder.max_flight_bytes < PACKET_MAX_SIZE: + builder.max_flight_bytes = PACKET_MAX_SIZE + + # limit data on un-validated network paths + if not network_path.is_validated: + builder.max_total_bytes = ( + 
network_path.bytes_received * 3 - network_path.bytes_sent + ) + + try: + if not self._handshake_confirmed: + for epoch in [tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE]: + self._write_handshake(builder, epoch, now) + self._write_application(builder, network_path, now) + except QuicPacketBuilderStop: + pass + + datagrams, packets = builder.flush() + + if datagrams: + self._packet_number = builder.packet_number + + # register packets + sent_handshake = False + for packet in packets: + packet.sent_time = now + self._loss.on_packet_sent( + packet=packet, space=self._spaces[packet.epoch] + ) + if packet.epoch == tls.Epoch.HANDSHAKE: + sent_handshake = True + + # log packet + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_sent", + data={ + "packet_type": self._quic_logger.packet_type( + packet.packet_type + ), + "header": { + "packet_number": str(packet.packet_number), + "packet_size": packet.sent_bytes, + "scid": dump_cid(self.host_cid) + if is_long_header(packet.packet_type) + else "", + "dcid": dump_cid(self._peer_cid), + }, + "frames": packet.quic_logger_frames, + }, + ) + + # check if we can discard initial keys + if sent_handshake and self._is_client: + self._discard_epoch(tls.Epoch.INITIAL) + + # return datagrams to send and the destination network address + ret = [] + for datagram in datagrams: + byte_length = len(datagram) + network_path.bytes_sent += byte_length + ret.append((datagram, network_path.addr)) + + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="datagrams_sent", + data={"byte_length": byte_length, "count": 1}, + ) + return ret + + def get_next_available_stream_id(self, is_unidirectional=False) -> int: + """ + Return the stream ID for the next stream created by this endpoint. 
+ """ + stream_id = (int(is_unidirectional) << 1) | int(not self._is_client) + while stream_id in self._streams: + stream_id += 4 + return stream_id + + def get_timer(self) -> Optional[float]: + """ + Return the time at which the timer should fire or None if no timer is needed. + """ + timer_at = self._close_at + if self._state not in END_STATES: + # ack timer + for space in self._loss.spaces: + if space.ack_at is not None and space.ack_at < timer_at: + timer_at = space.ack_at + + # loss detection timer + self._loss_at = self._loss.get_loss_detection_time() + if self._loss_at is not None and self._loss_at < timer_at: + timer_at = self._loss_at + + # pacing timer + if self._pacing_at is not None and self._pacing_at < timer_at: + timer_at = self._pacing_at + + return timer_at + + def handle_timer(self, now: float) -> None: + """ + Handle the timer. + + After calling this method call :meth:`datagrams_to_send` to retrieve data + which needs to be sent. + + :param now: The current time. + """ + # end of closing period or idle timeout + if now >= self._close_at: + if self._close_event is None: + self._close_event = events.ConnectionTerminated( + error_code=QuicErrorCode.INTERNAL_ERROR, + frame_type=None, + reason_phrase="Idle timeout", + ) + self._close_end() + return + + # loss detection timeout + if self._loss_at is not None and now >= self._loss_at: + self._logger.debug("Loss detection triggered") + self._loss.on_loss_detection_timeout(now=now) + + def next_event(self) -> Optional[events.QuicEvent]: + """ + Retrieve the next event from the event buffer. + + Returns `None` if there are no buffered events. + """ + try: + return self._events.popleft() + except IndexError: + return None + + def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> None: + """ + Handle an incoming datagram. + + After calling this method call :meth:`datagrams_to_send` to retrieve data + which needs to be sent. + + :param data: The datagram which was received. 
+ :param addr: The network address from which the datagram was received. + :param now: The current time. + """ + # stop handling packets when closing + if self._state in END_STATES: + return + + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="datagrams_received", + data={"byte_length": len(data), "count": 1}, + ) + + buf = Buffer(data=data) + while not buf.eof(): + start_off = buf.tell() + try: + header = pull_quic_header( + buf, host_cid_length=self._configuration.connection_id_length + ) + except ValueError: + return + + # check destination CID matches + destination_cid_seq: Optional[int] = None + for connection_id in self._host_cids: + if header.destination_cid == connection_id.cid: + destination_cid_seq = connection_id.sequence_number + break + if self._is_client and destination_cid_seq is None: + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_dropped", + data={"trigger": "unknown_connection_id"}, + ) + return + + # check protocol version + if ( + self._is_client + and self._state == QuicConnectionState.FIRSTFLIGHT + and header.version == QuicProtocolVersion.NEGOTIATION + ): + # version negotiation + versions = [] + while not buf.eof(): + versions.append(buf.pull_uint32()) + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_received", + data={ + "packet_type": "version_negotiation", + "header": { + "scid": dump_cid(header.source_cid), + "dcid": dump_cid(header.destination_cid), + }, + "frames": [], + }, + ) + common = set(self._configuration.supported_versions).intersection( + versions + ) + if not common: + self._logger.error("Could not find a common protocol version") + self._close_event = events.ConnectionTerminated( + error_code=QuicErrorCode.INTERNAL_ERROR, + frame_type=None, + reason_phrase="Could not find a common protocol version", + ) + self._close_end() + return + self._version = 
QuicProtocolVersion(max(common)) + self._logger.info("Retrying with %s", self._version) + self._connect(now=now) + return + elif ( + header.version is not None + and header.version not in self._configuration.supported_versions + ): + # unsupported version + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_dropped", + data={"trigger": "unsupported_version"}, + ) + return + + if self._is_client and header.packet_type == PACKET_TYPE_RETRY: + # calculate stateless retry integrity tag + integrity_tag = get_retry_integrity_tag( + buf.data_slice(start_off, buf.tell() - RETRY_INTEGRITY_TAG_SIZE), + self._peer_cid, + ) + + if ( + header.destination_cid == self.host_cid + and header.integrity_tag == integrity_tag + and not self._stateless_retry_count + ): + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_received", + data={ + "packet_type": "retry", + "header": { + "scid": dump_cid(header.source_cid), + "dcid": dump_cid(header.destination_cid), + }, + "frames": [], + }, + ) + + self._original_connection_id = self._peer_cid + self._peer_cid = header.source_cid + self._peer_token = header.token + self._stateless_retry_count += 1 + self._logger.info("Performing stateless retry") + self._connect(now=now) + return + + network_path = self._find_network_path(addr) + + # server initialization + if not self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT: + assert ( + header.packet_type == PACKET_TYPE_INITIAL + ), "first packet must be INITIAL" + self._network_paths = [network_path] + self._version = QuicProtocolVersion(header.version) + self._initialize(header.destination_cid) + + # determine crypto and packet space + epoch = get_epoch(header.packet_type) + crypto = self._cryptos[epoch] + if epoch == tls.Epoch.ZERO_RTT: + space = self._spaces[tls.Epoch.ONE_RTT] + else: + space = self._spaces[epoch] + + # decrypt packet + encrypted_off = buf.tell() - 
start_off + end_off = buf.tell() + header.rest_length + buf.seek(end_off) + + try: + plain_header, plain_payload, packet_number = crypto.decrypt_packet( + data[start_off:end_off], encrypted_off, space.expected_packet_number + ) + except KeyUnavailableError as exc: + self._logger.debug(exc) + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_dropped", + data={"trigger": "key_unavailable"}, + ) + continue + except CryptoError as exc: + self._logger.debug(exc) + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="packet_dropped", + data={"trigger": "payload_decrypt_error"}, + ) + continue + + # check reserved bits + if header.is_long_header: + reserved_mask = 0x0C + else: + reserved_mask = 0x18 + if plain_header[0] & reserved_mask: + self.close( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=None, + reason_phrase="Reserved bits must be zero", + ) + return + + # raise expected packet number + if packet_number > space.expected_packet_number: + space.expected_packet_number = packet_number + 1 + + # log packet + quic_logger_frames: Optional[List[Dict]] = None + if self._quic_logger is not None: + quic_logger_frames = [] + self._quic_logger.log_event( + category="transport", + event="packet_received", + data={ + "packet_type": self._quic_logger.packet_type( + header.packet_type + ), + "header": { + "packet_number": str(packet_number), + "packet_size": end_off - start_off, + "dcid": dump_cid(header.destination_cid), + "scid": dump_cid(header.source_cid), + }, + "frames": quic_logger_frames, + }, + ) + + # discard initial keys and packet space + if not self._is_client and epoch == tls.Epoch.HANDSHAKE: + self._discard_epoch(tls.Epoch.INITIAL) + + # update state + if self._peer_cid_seq is None: + self._peer_cid = header.source_cid + self._peer_cid_seq = 0 + + if self._state == QuicConnectionState.FIRSTFLIGHT: + self._set_state(QuicConnectionState.CONNECTED) + + # 
update spin bit + if not header.is_long_header and packet_number > self._spin_highest_pn: + spin_bit = get_spin_bit(plain_header[0]) + if self._is_client: + self._spin_bit = not spin_bit + else: + self._spin_bit = spin_bit + self._spin_highest_pn = packet_number + + if self._quic_logger is not None: + self._quic_logger.log_event( + category="connectivity", + event="spin_bit_updated", + data={"state": self._spin_bit}, + ) + + # handle payload + context = QuicReceiveContext( + epoch=epoch, + host_cid=header.destination_cid, + network_path=network_path, + quic_logger_frames=quic_logger_frames, + time=now, + ) + try: + is_ack_eliciting, is_probing = self._payload_received( + context, plain_payload + ) + except QuicConnectionError as exc: + self._logger.warning(exc) + self.close( + error_code=exc.error_code, + frame_type=exc.frame_type, + reason_phrase=exc.reason_phrase, + ) + if self._state in END_STATES or self._close_pending: + return + + # update idle timeout + self._close_at = now + self._configuration.idle_timeout + + # handle migration + if ( + not self._is_client + and context.host_cid != self.host_cid + and epoch == tls.Epoch.ONE_RTT + ): + self._logger.debug( + "Peer switching to CID %s (%d)", + dump_cid(context.host_cid), + destination_cid_seq, + ) + self.host_cid = context.host_cid + self.change_connection_id() + + # update network path + if not network_path.is_validated and epoch == tls.Epoch.HANDSHAKE: + self._logger.debug( + "Network path %s validated by handshake", network_path.addr + ) + network_path.is_validated = True + network_path.bytes_received += end_off - start_off + if network_path not in self._network_paths: + self._network_paths.append(network_path) + idx = self._network_paths.index(network_path) + if idx and not is_probing and packet_number > space.largest_received_packet: + self._logger.debug("Network path %s promoted", network_path.addr) + self._network_paths.pop(idx) + self._network_paths.insert(0, network_path) + + # record packet as 
    def request_key_update(self) -> None:
        """
        Request an update of the encryption keys.
        """
        assert self._handshake_complete, "cannot change key before handshake completes"
        self._cryptos[tls.Epoch.ONE_RTT].update_key()

    def send_ping(self, uid: int) -> None:
        """
        Send a PING frame to the peer.

        :param uid: A unique ID for this PING.
        """
        # Only queues the PING; it goes out on the next datagrams_to_send().
        self._ping_pending.append(uid)

    def send_datagram_frame(self, data: bytes) -> None:
        """
        Send a DATAGRAM frame.

        :param data: The data to be sent.
        """
        # Only queues the datagram; it goes out on the next datagrams_to_send().
        self._datagrams_pending.append(data)

    def send_stream_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        """
        Send data on the specific stream.

        :param stream_id: The stream's ID.
        :param data: The data to be sent.
        :param end_stream: If set to `True`, the FIN bit will be set.
        """
        # Peer-initiated streams must already exist and be bidirectional
        # for us to send on them.
        if stream_is_client_initiated(stream_id) != self._is_client:
            if stream_id not in self._streams:
                raise ValueError("Cannot send data on unknown peer-initiated stream")
            if stream_is_unidirectional(stream_id):
                raise ValueError(
                    "Cannot send data on peer-initiated unidirectional stream"
                )

        # Locally-initiated streams are created on first use.
        try:
            stream = self._streams[stream_id]
        except KeyError:
            self._create_stream(stream_id=stream_id)
            stream = self._streams[stream_id]
        stream.write(data, end_stream=end_stream)
+ """ + self._events.append(events.ProtocolNegotiated(alpn_protocol=alpn_protocol)) + + def _assert_stream_can_receive(self, frame_type: int, stream_id: int) -> None: + """ + Check the specified stream can receive data or raises a QuicConnectionError. + """ + if not self._stream_can_receive(stream_id): + raise QuicConnectionError( + error_code=QuicErrorCode.STREAM_STATE_ERROR, + frame_type=frame_type, + reason_phrase="Stream is send-only", + ) + + def _assert_stream_can_send(self, frame_type: int, stream_id: int) -> None: + """ + Check the specified stream can send data or raises a QuicConnectionError. + """ + if not self._stream_can_send(stream_id): + raise QuicConnectionError( + error_code=QuicErrorCode.STREAM_STATE_ERROR, + frame_type=frame_type, + reason_phrase="Stream is receive-only", + ) + + def _close_begin(self, is_initiator: bool, now: float) -> None: + """ + Begin the close procedure. + """ + self._close_at = now + 3 * self._loss.get_probe_timeout() + if is_initiator: + self._set_state(QuicConnectionState.CLOSING) + else: + self._set_state(QuicConnectionState.DRAINING) + + def _close_end(self) -> None: + """ + End the close procedure. + """ + self._close_at = None + for epoch in self._spaces.keys(): + self._discard_epoch(epoch) + self._events.append(self._close_event) + self._set_state(QuicConnectionState.TERMINATED) + + # signal log end + if self._quic_logger is not None: + self._configuration.quic_logger.end_trace(self._quic_logger) + self._quic_logger = None + + def _connect(self, now: float) -> None: + """ + Start the client handshake. + """ + assert self._is_client + + self._close_at = now + self._configuration.idle_timeout + self._initialize(self._peer_cid) + + self.tls.handle_message(b"", self._crypto_buffers) + self._push_crypto_data() + + def _create_stream(self, stream_id: int) -> QuicStream: + """ + Create a QUIC stream in order to send data to the peer. 
+ """ + # determine limits + if stream_is_unidirectional(stream_id): + max_stream_data_local = 0 + max_stream_data_remote = self._remote_max_stream_data_uni + max_streams = self._remote_max_streams_uni + streams_blocked = self._streams_blocked_uni + else: + max_stream_data_local = self._local_max_stream_data_bidi_local + max_stream_data_remote = self._remote_max_stream_data_bidi_remote + max_streams = self._remote_max_streams_bidi + streams_blocked = self._streams_blocked_bidi + + # create stream + stream = self._streams[stream_id] = QuicStream( + stream_id=stream_id, + max_stream_data_local=max_stream_data_local, + max_stream_data_remote=max_stream_data_remote, + ) + + # mark stream as blocked if needed + if stream_id // 4 >= max_streams: + stream.is_blocked = True + streams_blocked.append(stream) + self._streams_blocked_pending = True + + return stream + + def _discard_epoch(self, epoch: tls.Epoch) -> None: + self._logger.debug("Discarding epoch %s", epoch) + self._cryptos[epoch].teardown() + self._loss.discard_space(self._spaces[epoch]) + self._spaces[epoch].discarded = True + + def _find_network_path(self, addr: NetworkAddress) -> QuicNetworkPath: + # check existing network paths + for idx, network_path in enumerate(self._network_paths): + if network_path.addr == addr: + return network_path + + # new network path + network_path = QuicNetworkPath(addr) + self._logger.debug("Network path %s discovered", network_path.addr) + return network_path + + def _get_or_create_stream(self, frame_type: int, stream_id: int) -> QuicStream: + """ + Get or create a stream in response to a received frame. 
+ """ + stream = self._streams.get(stream_id, None) + if stream is None: + # check initiator + if stream_is_client_initiated(stream_id) == self._is_client: + raise QuicConnectionError( + error_code=QuicErrorCode.STREAM_STATE_ERROR, + frame_type=frame_type, + reason_phrase="Wrong stream initiator", + ) + + # determine limits + if stream_is_unidirectional(stream_id): + max_stream_data_local = self._local_max_stream_data_uni + max_stream_data_remote = 0 + max_streams = self._local_max_streams_uni + else: + max_stream_data_local = self._local_max_stream_data_bidi_remote + max_stream_data_remote = self._remote_max_stream_data_bidi_local + max_streams = self._local_max_streams_bidi + + # check max streams + if stream_id // 4 >= max_streams: + raise QuicConnectionError( + error_code=QuicErrorCode.STREAM_LIMIT_ERROR, + frame_type=frame_type, + reason_phrase="Too many streams open", + ) + + # create stream + self._logger.debug("Stream %d created by peer" % stream_id) + stream = self._streams[stream_id] = QuicStream( + stream_id=stream_id, + max_stream_data_local=max_stream_data_local, + max_stream_data_remote=max_stream_data_remote, + ) + return stream + + def _handle_session_ticket(self, session_ticket: tls.SessionTicket) -> None: + if ( + session_ticket.max_early_data_size is not None + and session_ticket.max_early_data_size != MAX_EARLY_DATA + ): + raise QuicConnectionError( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=QuicFrameType.CRYPTO, + reason_phrase="Invalid max_early_data value %s" + % session_ticket.max_early_data_size, + ) + self._session_ticket_handler(session_ticket) + + def _initialize(self, peer_cid: bytes) -> None: + # TLS + self.tls = tls.Context( + alpn_protocols=self._configuration.alpn_protocols, + cadata=self._configuration.cadata, + cafile=self._configuration.cafile, + capath=self._configuration.capath, + is_client=self._is_client, + logger=self._logger, + max_early_data=None if self._is_client else MAX_EARLY_DATA, + 
server_name=self._configuration.server_name, + verify_mode=self._configuration.verify_mode, + ) + self.tls.certificate = self._configuration.certificate + self.tls.certificate_chain = self._configuration.certificate_chain + self.tls.certificate_private_key = self._configuration.private_key + self.tls.handshake_extensions = [ + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + self._serialize_transport_parameters(), + ) + ] + + # TLS session resumption + session_ticket = self._configuration.session_ticket + if ( + self._is_client + and session_ticket is not None + and session_ticket.is_valid + and session_ticket.server_name == self._configuration.server_name + ): + self.tls.session_ticket = self._configuration.session_ticket + + # parse saved QUIC transport parameters - for 0-RTT + if session_ticket.max_early_data_size == MAX_EARLY_DATA: + for ext_type, ext_data in session_ticket.other_extensions: + if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS: + self._parse_transport_parameters( + ext_data, from_session_ticket=True + ) + break + + # TLS callbacks + self.tls.alpn_cb = self._alpn_handler + if self._session_ticket_fetcher is not None: + self.tls.get_session_ticket_cb = self._session_ticket_fetcher + if self._session_ticket_handler is not None: + self.tls.new_session_ticket_cb = self._handle_session_ticket + self.tls.update_traffic_key_cb = self._update_traffic_key + + # packet spaces + self._cryptos = { + tls.Epoch.INITIAL: CryptoPair(), + tls.Epoch.ZERO_RTT: CryptoPair(), + tls.Epoch.HANDSHAKE: CryptoPair(), + tls.Epoch.ONE_RTT: CryptoPair(), + } + self._crypto_buffers = { + tls.Epoch.INITIAL: Buffer(capacity=4096), + tls.Epoch.HANDSHAKE: Buffer(capacity=4096), + tls.Epoch.ONE_RTT: Buffer(capacity=4096), + } + self._crypto_streams = { + tls.Epoch.INITIAL: QuicStream(), + tls.Epoch.HANDSHAKE: QuicStream(), + tls.Epoch.ONE_RTT: QuicStream(), + } + self._spaces = { + tls.Epoch.INITIAL: QuicPacketSpace(), + tls.Epoch.HANDSHAKE: QuicPacketSpace(), + 
tls.Epoch.ONE_RTT: QuicPacketSpace(), + } + + self._cryptos[tls.Epoch.INITIAL].setup_initial( + cid=peer_cid, is_client=self._is_client, version=self._version + ) + + self._loss.spaces = list(self._spaces.values()) + self._packet_number = 0 + + def _handle_ack_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle an ACK frame. + """ + ack_rangeset, ack_delay_encoded = pull_ack_frame(buf) + if frame_type == QuicFrameType.ACK_ECN: + buf.pull_uint_var() + buf.pull_uint_var() + buf.pull_uint_var() + ack_delay = (ack_delay_encoded << self._remote_ack_delay_exponent) / 1000000 + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_ack_frame(ack_rangeset, ack_delay) + ) + + self._loss.on_ack_received( + space=self._spaces[context.epoch], + ack_rangeset=ack_rangeset, + ack_delay=ack_delay, + now=context.time, + ) + + def _handle_connection_close_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a CONNECTION_CLOSE frame. 
+ """ + error_code = buf.pull_uint_var() + if frame_type == QuicFrameType.TRANSPORT_CLOSE: + frame_type = buf.pull_uint_var() + else: + frame_type = None + reason_length = buf.pull_uint_var() + try: + reason_phrase = buf.pull_bytes(reason_length).decode("utf8") + except UnicodeDecodeError: + reason_phrase = "" + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_connection_close_frame( + error_code=error_code, + frame_type=frame_type, + reason_phrase=reason_phrase, + ) + ) + + self._logger.info( + "Connection close code 0x%X, reason %s", error_code, reason_phrase + ) + self._close_event = events.ConnectionTerminated( + error_code=error_code, frame_type=frame_type, reason_phrase=reason_phrase + ) + self._close_begin(is_initiator=False, now=context.time) + + def _handle_crypto_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a CRYPTO frame. + """ + offset = buf.pull_uint_var() + length = buf.pull_uint_var() + if offset + length > UINT_VAR_MAX: + raise QuicConnectionError( + error_code=QuicErrorCode.FRAME_ENCODING_ERROR, + frame_type=frame_type, + reason_phrase="offset + length cannot exceed 2^62 - 1", + ) + frame = QuicStreamFrame(offset=offset, data=buf.pull_bytes(length)) + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_crypto_frame(frame) + ) + + stream = self._crypto_streams[context.epoch] + event = stream.add_frame(frame) + if event is not None: + # pass data to TLS layer + try: + self.tls.handle_message(event.data, self._crypto_buffers) + self._push_crypto_data() + except tls.Alert as exc: + raise QuicConnectionError( + error_code=QuicErrorCode.CRYPTO_ERROR + int(exc.description), + frame_type=frame_type, + reason_phrase=str(exc), + ) + + # parse transport parameters + if ( + not self._parameters_received + and self.tls.received_extensions is not None + ): + for ext_type, ext_data 
in self.tls.received_extensions: + if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS: + self._parse_transport_parameters(ext_data) + self._parameters_received = True + break + assert ( + self._parameters_received + ), "No QUIC transport parameters received" + + # update current epoch + if not self._handshake_complete and self.tls.state in [ + tls.State.CLIENT_POST_HANDSHAKE, + tls.State.SERVER_POST_HANDSHAKE, + ]: + self._handshake_complete = True + + # for servers, the handshake is now confirmed + if not self._is_client: + self._discard_epoch(tls.Epoch.HANDSHAKE) + self._handshake_confirmed = True + self._handshake_done_pending = True + + self._loss.is_client_without_1rtt = False + self._replenish_connection_ids() + self._events.append( + events.HandshakeCompleted( + alpn_protocol=self.tls.alpn_negotiated, + early_data_accepted=self.tls.early_data_accepted, + session_resumed=self.tls.session_resumed, + ) + ) + self._unblock_streams(is_unidirectional=False) + self._unblock_streams(is_unidirectional=True) + self._logger.info( + "ALPN negotiated protocol %s", self.tls.alpn_negotiated + ) + + def _handle_data_blocked_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a DATA_BLOCKED frame. + """ + limit = buf.pull_uint_var() + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_data_blocked_frame(limit=limit) + ) + + def _handle_datagram_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a DATAGRAM frame. 
+ """ + start = buf.tell() + if frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH: + length = buf.pull_uint_var() + else: + length = buf.capacity - start + data = buf.pull_bytes(length) + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_datagram_frame(length=length) + ) + + # check frame is allowed + if ( + self._configuration.max_datagram_frame_size is None + or buf.tell() - start >= self._configuration.max_datagram_frame_size + ): + raise QuicConnectionError( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=frame_type, + reason_phrase="Unexpected DATAGRAM frame", + ) + + self._events.append(events.DatagramFrameReceived(data=data)) + + def _handle_handshake_done_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a HANDSHAKE_DONE frame. + """ + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_handshake_done_frame() + ) + + if not self._is_client: + raise QuicConnectionError( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=frame_type, + reason_phrase="Clients must not send HANDSHAKE_DONE frames", + ) + + #  for clients, the handshake is now confirmed + if not self._handshake_confirmed: + self._discard_epoch(tls.Epoch.HANDSHAKE) + self._handshake_confirmed = True + + def _handle_max_data_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a MAX_DATA frame. + + This adjusts the total amount of we can send to the peer. 
+ """ + max_data = buf.pull_uint_var() + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_max_data_frame(maximum=max_data) + ) + + if max_data > self._remote_max_data: + self._logger.debug("Remote max_data raised to %d", max_data) + self._remote_max_data = max_data + + def _handle_max_stream_data_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a MAX_STREAM_DATA frame. + + This adjusts the amount of data we can send on a specific stream. + """ + stream_id = buf.pull_uint_var() + max_stream_data = buf.pull_uint_var() + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_max_stream_data_frame( + maximum=max_stream_data, stream_id=stream_id + ) + ) + + # check stream direction + self._assert_stream_can_send(frame_type, stream_id) + + stream = self._get_or_create_stream(frame_type, stream_id) + if max_stream_data > stream.max_stream_data_remote: + self._logger.debug( + "Stream %d remote max_stream_data raised to %d", + stream_id, + max_stream_data, + ) + stream.max_stream_data_remote = max_stream_data + + def _handle_max_streams_bidi_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a MAX_STREAMS_BIDI frame. + + This raises number of bidirectional streams we can initiate to the peer. 
+ """ + max_streams = buf.pull_uint_var() + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_max_streams_frame( + is_unidirectional=False, maximum=max_streams + ) + ) + + if max_streams > self._remote_max_streams_bidi: + self._logger.debug("Remote max_streams_bidi raised to %d", max_streams) + self._remote_max_streams_bidi = max_streams + self._unblock_streams(is_unidirectional=False) + + def _handle_max_streams_uni_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a MAX_STREAMS_UNI frame. + + This raises number of unidirectional streams we can initiate to the peer. + """ + max_streams = buf.pull_uint_var() + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_max_streams_frame( + is_unidirectional=True, maximum=max_streams + ) + ) + + if max_streams > self._remote_max_streams_uni: + self._logger.debug("Remote max_streams_uni raised to %d", max_streams) + self._remote_max_streams_uni = max_streams + self._unblock_streams(is_unidirectional=True) + + def _handle_new_connection_id_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a NEW_CONNECTION_ID frame. 
+ """ + sequence_number = buf.pull_uint_var() + retire_prior_to = buf.pull_uint_var() + length = buf.pull_uint8() + connection_id = buf.pull_bytes(length) + stateless_reset_token = buf.pull_bytes(16) + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_new_connection_id_frame( + connection_id=connection_id, + retire_prior_to=retire_prior_to, + sequence_number=sequence_number, + stateless_reset_token=stateless_reset_token, + ) + ) + + self._peer_cid_available.append( + QuicConnectionId( + cid=connection_id, + sequence_number=sequence_number, + stateless_reset_token=stateless_reset_token, + ) + ) + + def _handle_new_token_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a NEW_TOKEN frame. + """ + length = buf.pull_uint_var() + token = buf.pull_bytes(length) + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_new_token_frame(token=token) + ) + + if not self._is_client: + raise QuicConnectionError( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=frame_type, + reason_phrase="Clients must not send NEW_TOKEN frames", + ) + + def _handle_padding_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a PADDING frame. + """ + # consume padding + pos = buf.tell() + for byte in buf.data_slice(pos, buf.capacity): + if byte: + break + pos += 1 + buf.seek(pos) + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append(self._quic_logger.encode_padding_frame()) + + def _handle_path_challenge_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a PATH_CHALLENGE frame. 
+ """ + data = buf.pull_bytes(8) + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_path_challenge_frame(data=data) + ) + + context.network_path.remote_challenge = data + + def _handle_path_response_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a PATH_RESPONSE frame. + """ + data = buf.pull_bytes(8) + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_path_response_frame(data=data) + ) + + if data != context.network_path.local_challenge: + raise QuicConnectionError( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=frame_type, + reason_phrase="Response does not match challenge", + ) + self._logger.debug( + "Network path %s validated by challenge", context.network_path.addr + ) + context.network_path.is_validated = True + + def _handle_ping_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a PING frame. + """ + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append(self._quic_logger.encode_ping_frame()) + + def _handle_reset_stream_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a RESET_STREAM frame. 
+ """ + stream_id = buf.pull_uint_var() + error_code = buf.pull_uint_var() + final_size = buf.pull_uint_var() + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_reset_stream_frame( + error_code=error_code, final_size=final_size, stream_id=stream_id + ) + ) + + # check stream direction + self._assert_stream_can_receive(frame_type, stream_id) + + self._logger.info( + "Stream %d reset by peer (error code %d, final size %d)", + stream_id, + error_code, + final_size, + ) + # stream = self._get_or_create_stream(frame_type, stream_id) + self._events.append( + events.StreamReset(error_code=error_code, stream_id=stream_id) + ) + + def _handle_retire_connection_id_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a RETIRE_CONNECTION_ID frame. + """ + sequence_number = buf.pull_uint_var() + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_retire_connection_id_frame(sequence_number) + ) + + # find the connection ID by sequence number + for index, connection_id in enumerate(self._host_cids): + if connection_id.sequence_number == sequence_number: + if connection_id.cid == context.host_cid: + raise QuicConnectionError( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=frame_type, + reason_phrase="Cannot retire current connection ID", + ) + self._logger.debug( + "Peer retiring CID %s (%d)", + dump_cid(connection_id.cid), + connection_id.sequence_number, + ) + del self._host_cids[index] + self._events.append( + events.ConnectionIdRetired(connection_id=connection_id.cid) + ) + break + + # issue a new connection ID + self._replenish_connection_ids() + + def _handle_stop_sending_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a STOP_SENDING frame. 
+ """ + stream_id = buf.pull_uint_var() + error_code = buf.pull_uint_var() # application error code + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_stop_sending_frame( + error_code=error_code, stream_id=stream_id + ) + ) + + # check stream direction + self._assert_stream_can_send(frame_type, stream_id) + + self._get_or_create_stream(frame_type, stream_id) + + def _handle_stream_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a STREAM frame. + """ + stream_id = buf.pull_uint_var() + if frame_type & 4: + offset = buf.pull_uint_var() + else: + offset = 0 + if frame_type & 2: + length = buf.pull_uint_var() + else: + length = buf.capacity - buf.tell() + if offset + length > UINT_VAR_MAX: + raise QuicConnectionError( + error_code=QuicErrorCode.FRAME_ENCODING_ERROR, + frame_type=frame_type, + reason_phrase="offset + length cannot exceed 2^62 - 1", + ) + frame = QuicStreamFrame( + offset=offset, data=buf.pull_bytes(length), fin=bool(frame_type & 1) + ) + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_stream_frame(frame, stream_id=stream_id) + ) + + # check stream direction + self._assert_stream_can_receive(frame_type, stream_id) + + # check flow-control limits + stream = self._get_or_create_stream(frame_type, stream_id) + if offset + length > stream.max_stream_data_local: + raise QuicConnectionError( + error_code=QuicErrorCode.FLOW_CONTROL_ERROR, + frame_type=frame_type, + reason_phrase="Over stream data limit", + ) + newly_received = max(0, offset + length - stream._recv_highest) + if self._local_max_data_used + newly_received > self._local_max_data: + raise QuicConnectionError( + error_code=QuicErrorCode.FLOW_CONTROL_ERROR, + frame_type=frame_type, + reason_phrase="Over connection data limit", + ) + + event = stream.add_frame(frame) + if event is not None: + self._events.append(event) + 
self._local_max_data_used += newly_received + + def _handle_stream_data_blocked_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a STREAM_DATA_BLOCKED frame. + """ + stream_id = buf.pull_uint_var() + limit = buf.pull_uint_var() + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_stream_data_blocked_frame( + limit=limit, stream_id=stream_id + ) + ) + + # check stream direction + self._assert_stream_can_receive(frame_type, stream_id) + + self._get_or_create_stream(frame_type, stream_id) + + def _handle_streams_blocked_frame( + self, context: QuicReceiveContext, frame_type: int, buf: Buffer + ) -> None: + """ + Handle a STREAMS_BLOCKED frame. + """ + limit = buf.pull_uint_var() + + # log frame + if self._quic_logger is not None: + context.quic_logger_frames.append( + self._quic_logger.encode_streams_blocked_frame( + is_unidirectional=frame_type == QuicFrameType.STREAMS_BLOCKED_UNI, + limit=limit, + ) + ) + + def _on_ack_delivery( + self, delivery: QuicDeliveryState, space: QuicPacketSpace, highest_acked: int + ) -> None: + """ + Callback when an ACK frame is acknowledged or lost. + """ + if delivery == QuicDeliveryState.ACKED: + space.ack_queue.subtract(0, highest_acked + 1) + + def _on_handshake_done_delivery(self, delivery: QuicDeliveryState) -> None: + """ + Callback when a HANDSHAKE_DONE frame is acknowledged or lost. + """ + if delivery != QuicDeliveryState.ACKED: + self._handshake_done_pending = True + + def _on_max_data_delivery(self, delivery: QuicDeliveryState) -> None: + """ + Callback when a MAX_DATA frame is acknowledged or lost. + """ + if delivery != QuicDeliveryState.ACKED: + self._local_max_data_sent = 0 + + def _on_max_stream_data_delivery( + self, delivery: QuicDeliveryState, stream: QuicStream + ) -> None: + """ + Callback when a MAX_STREAM_DATA frame is acknowledged or lost. 
+ """ + if delivery != QuicDeliveryState.ACKED: + stream.max_stream_data_local_sent = 0 + + def _on_new_connection_id_delivery( + self, delivery: QuicDeliveryState, connection_id: QuicConnectionId + ) -> None: + """ + Callback when a NEW_CONNECTION_ID frame is acknowledged or lost. + """ + if delivery != QuicDeliveryState.ACKED: + connection_id.was_sent = False + + def _on_ping_delivery( + self, delivery: QuicDeliveryState, uids: Sequence[int] + ) -> None: + """ + Callback when a PING frame is acknowledged or lost. + """ + if delivery == QuicDeliveryState.ACKED: + self._logger.debug("Received PING%s response", "" if uids else " (probe)") + for uid in uids: + self._events.append(events.PingAcknowledged(uid=uid)) + else: + self._ping_pending.extend(uids) + + def _on_retire_connection_id_delivery( + self, delivery: QuicDeliveryState, sequence_number: int + ) -> None: + """ + Callback when a RETIRE_CONNECTION_ID frame is acknowledged or lost. + """ + if delivery != QuicDeliveryState.ACKED: + self._retire_connection_ids.append(sequence_number) + + def _payload_received( + self, context: QuicReceiveContext, plain: bytes + ) -> Tuple[bool, bool]: + """ + Handle a QUIC packet payload. 
+ """ + buf = Buffer(data=plain) + + is_ack_eliciting = False + is_probing = None + while not buf.eof(): + frame_type = buf.pull_uint_var() + + # check frame type is known + try: + frame_handler, frame_epochs = self.__frame_handlers[frame_type] + except KeyError: + raise QuicConnectionError( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=frame_type, + reason_phrase="Unknown frame type", + ) + + # check frame is allowed for the epoch + if context.epoch not in frame_epochs: + raise QuicConnectionError( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=frame_type, + reason_phrase="Unexpected frame type", + ) + + # handle the frame + try: + frame_handler(context, frame_type, buf) + except BufferReadError: + raise QuicConnectionError( + error_code=QuicErrorCode.FRAME_ENCODING_ERROR, + frame_type=frame_type, + reason_phrase="Failed to parse frame", + ) + + # update ACK only / probing flags + if frame_type not in NON_ACK_ELICITING_FRAME_TYPES: + is_ack_eliciting = True + + if frame_type not in PROBING_FRAME_TYPES: + is_probing = False + elif is_probing is None: + is_probing = True + + return is_ack_eliciting, bool(is_probing) + + def _replenish_connection_ids(self) -> None: + """ + Generate new connection IDs. 
+ """ + while len(self._host_cids) < min(8, self._remote_active_connection_id_limit): + self._host_cids.append( + QuicConnectionId( + cid=os.urandom(self._configuration.connection_id_length), + sequence_number=self._host_cid_seq, + stateless_reset_token=os.urandom(16), + ) + ) + self._host_cid_seq += 1 + + def _push_crypto_data(self) -> None: + for epoch, buf in self._crypto_buffers.items(): + self._crypto_streams[epoch].write(buf.data) + buf.seek(0) + + def _send_probe(self) -> None: + self._probe_pending = True + + def _parse_transport_parameters( + self, data: bytes, from_session_ticket: bool = False + ) -> None: + quic_transport_parameters = pull_quic_transport_parameters( + Buffer(data=data), protocol_version=self._version + ) + + # log event + if self._quic_logger is not None and not from_session_ticket: + self._quic_logger.log_event( + category="transport", + event="parameters_set", + data=self._quic_logger.encode_transport_parameters( + owner="remote", parameters=quic_transport_parameters + ), + ) + + # validate remote parameters + if ( + self._is_client + and not from_session_ticket + and ( + quic_transport_parameters.original_connection_id + != self._original_connection_id + ) + ): + raise QuicConnectionError( + error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, + frame_type=QuicFrameType.CRYPTO, + reason_phrase="original_connection_id does not match", + ) + + # store remote parameters + if quic_transport_parameters.ack_delay_exponent is not None: + self._remote_ack_delay_exponent = self._remote_ack_delay_exponent + if quic_transport_parameters.active_connection_id_limit is not None: + self._remote_active_connection_id_limit = ( + quic_transport_parameters.active_connection_id_limit + ) + if quic_transport_parameters.idle_timeout is not None: + self._remote_idle_timeout = quic_transport_parameters.idle_timeout / 1000.0 + if quic_transport_parameters.max_ack_delay is not None: + self._loss.max_ack_delay = quic_transport_parameters.max_ack_delay / 1000.0 + 
self._remote_max_datagram_frame_size = ( + quic_transport_parameters.max_datagram_frame_size + ) + for param in [ + "max_data", + "max_stream_data_bidi_local", + "max_stream_data_bidi_remote", + "max_stream_data_uni", + "max_streams_bidi", + "max_streams_uni", + ]: + value = getattr(quic_transport_parameters, "initial_" + param) + if value is not None: + setattr(self, "_remote_" + param, value) + + def _serialize_transport_parameters(self) -> bytes: + quic_transport_parameters = QuicTransportParameters( + ack_delay_exponent=self._local_ack_delay_exponent, + active_connection_id_limit=self._local_active_connection_id_limit, + idle_timeout=int(self._configuration.idle_timeout * 1000), + initial_max_data=self._local_max_data, + initial_max_stream_data_bidi_local=self._local_max_stream_data_bidi_local, + initial_max_stream_data_bidi_remote=self._local_max_stream_data_bidi_remote, + initial_max_stream_data_uni=self._local_max_stream_data_uni, + initial_max_streams_bidi=self._local_max_streams_bidi, + initial_max_streams_uni=self._local_max_streams_uni, + max_ack_delay=25, + max_datagram_frame_size=self._configuration.max_datagram_frame_size, + quantum_readiness=b"Q" * 1200 + if self._configuration.quantum_readiness_test + else None, + ) + if not self._is_client: + quic_transport_parameters.original_connection_id = ( + self._original_connection_id + ) + + # log event + if self._quic_logger is not None: + self._quic_logger.log_event( + category="transport", + event="parameters_set", + data=self._quic_logger.encode_transport_parameters( + owner="local", parameters=quic_transport_parameters + ), + ) + + buf = Buffer(capacity=3 * PACKET_MAX_SIZE) + push_quic_transport_parameters( + buf, quic_transport_parameters, protocol_version=self._version + ) + return buf.data + + def _set_state(self, state: QuicConnectionState) -> None: + self._logger.debug("%s -> %s", self._state, state) + self._state = state + + def _stream_can_receive(self, stream_id: int) -> bool: + return 
stream_is_client_initiated( + stream_id + ) != self._is_client or not stream_is_unidirectional(stream_id) + + def _stream_can_send(self, stream_id: int) -> bool: + return stream_is_client_initiated( + stream_id + ) == self._is_client or not stream_is_unidirectional(stream_id) + + def _unblock_streams(self, is_unidirectional: bool) -> None: + if is_unidirectional: + max_stream_data_remote = self._remote_max_stream_data_uni + max_streams = self._remote_max_streams_uni + streams_blocked = self._streams_blocked_uni + else: + max_stream_data_remote = self._remote_max_stream_data_bidi_remote + max_streams = self._remote_max_streams_bidi + streams_blocked = self._streams_blocked_bidi + + while streams_blocked and streams_blocked[0].stream_id // 4 < max_streams: + stream = streams_blocked.pop(0) + stream.is_blocked = False + stream.max_stream_data_remote = max_stream_data_remote + + if not self._streams_blocked_bidi and not self._streams_blocked_uni: + self._streams_blocked_pending = False + + def _update_traffic_key( + self, + direction: tls.Direction, + epoch: tls.Epoch, + cipher_suite: tls.CipherSuite, + secret: bytes, + ) -> None: + """ + Callback which is invoked by the TLS engine when new traffic keys are + available. 
+ """ + secrets_log_file = self._configuration.secrets_log_file + if secrets_log_file is not None: + label_row = self._is_client == (direction == tls.Direction.DECRYPT) + label = SECRETS_LABELS[label_row][epoch.value] + secrets_log_file.write( + "%s %s %s\n" % (label, self.tls.client_random.hex(), secret.hex()) + ) + secrets_log_file.flush() + + crypto = self._cryptos[epoch] + if direction == tls.Direction.ENCRYPT: + crypto.send.setup( + cipher_suite=cipher_suite, secret=secret, version=self._version + ) + else: + crypto.recv.setup( + cipher_suite=cipher_suite, secret=secret, version=self._version + ) + + def _write_application( + self, builder: QuicPacketBuilder, network_path: QuicNetworkPath, now: float + ) -> None: + crypto_stream: Optional[QuicStream] = None + if self._cryptos[tls.Epoch.ONE_RTT].send.is_valid(): + crypto = self._cryptos[tls.Epoch.ONE_RTT] + crypto_stream = self._crypto_streams[tls.Epoch.ONE_RTT] + packet_type = PACKET_TYPE_ONE_RTT + elif self._cryptos[tls.Epoch.ZERO_RTT].send.is_valid(): + crypto = self._cryptos[tls.Epoch.ZERO_RTT] + packet_type = PACKET_TYPE_ZERO_RTT + else: + return + space = self._spaces[tls.Epoch.ONE_RTT] + + while True: + # apply pacing, except if we have ACKs to send + if space.ack_at is None or space.ack_at >= now: + self._pacing_at = self._loss._pacer.next_send_time(now=now) + if self._pacing_at is not None: + break + builder.start_packet(packet_type, crypto) + + if self._handshake_complete: + # ACK + if space.ack_at is not None and space.ack_at <= now: + self._write_ack_frame(builder=builder, space=space, now=now) + + # HANDSHAKE_DONE + if self._handshake_done_pending: + self._write_handshake_done_frame(builder=builder) + self._handshake_done_pending = False + + # PATH CHALLENGE + if ( + not network_path.is_validated + and network_path.local_challenge is None + ): + challenge = os.urandom(8) + self._write_path_challenge_frame( + builder=builder, challenge=challenge + ) + network_path.local_challenge = challenge + + # 
PATH RESPONSE + if network_path.remote_challenge is not None: + self._write_path_response_frame( + builder=builder, challenge=network_path.remote_challenge + ) + network_path.remote_challenge = None + + # NEW_CONNECTION_ID + for connection_id in self._host_cids: + if not connection_id.was_sent: + self._write_new_connection_id_frame( + builder=builder, connection_id=connection_id + ) + + # RETIRE_CONNECTION_ID + while self._retire_connection_ids: + sequence_number = self._retire_connection_ids.pop(0) + self._write_retire_connection_id_frame( + builder=builder, sequence_number=sequence_number + ) + + # STREAMS_BLOCKED + if self._streams_blocked_pending: + if self._streams_blocked_bidi: + self._write_streams_blocked_frame( + builder=builder, + frame_type=QuicFrameType.STREAMS_BLOCKED_BIDI, + limit=self._remote_max_streams_bidi, + ) + if self._streams_blocked_uni: + self._write_streams_blocked_frame( + builder=builder, + frame_type=QuicFrameType.STREAMS_BLOCKED_UNI, + limit=self._remote_max_streams_uni, + ) + self._streams_blocked_pending = False + + # MAX_DATA + self._write_connection_limits(builder=builder, space=space) + + # stream-level limits + for stream in self._streams.values(): + self._write_stream_limits(builder=builder, space=space, stream=stream) + + # PING (user-request) + if self._ping_pending: + self._write_ping_frame(builder, self._ping_pending) + self._ping_pending.clear() + + # PING (probe) + if self._probe_pending: + self._write_ping_frame(builder, comment="probe") + self._probe_pending = False + + # CRYPTO + if crypto_stream is not None and not crypto_stream.send_buffer_is_empty: + self._write_crypto_frame( + builder=builder, space=space, stream=crypto_stream + ) + + # DATAGRAM + while self._datagrams_pending: + try: + self._write_datagram_frame( + builder=builder, + data=self._datagrams_pending[0], + frame_type=QuicFrameType.DATAGRAM_WITH_LENGTH, + ) + self._datagrams_pending.popleft() + except QuicPacketBuilderStop: + break + + # STREAM + for 
stream in self._streams.values(): + if not stream.is_blocked and not stream.send_buffer_is_empty: + self._remote_max_data_used += self._write_stream_frame( + builder=builder, + space=space, + stream=stream, + max_offset=min( + stream._send_highest + + self._remote_max_data + - self._remote_max_data_used, + stream.max_stream_data_remote, + ), + ) + + if builder.packet_is_empty: + break + else: + self._loss._pacer.update_after_send(now=now) + + def _write_handshake( + self, builder: QuicPacketBuilder, epoch: tls.Epoch, now: float + ) -> None: + crypto = self._cryptos[epoch] + if not crypto.send.is_valid(): + return + + crypto_stream = self._crypto_streams[epoch] + space = self._spaces[epoch] + + while True: + if epoch == tls.Epoch.INITIAL: + packet_type = PACKET_TYPE_INITIAL + else: + packet_type = PACKET_TYPE_HANDSHAKE + builder.start_packet(packet_type, crypto) + + # ACK + if space.ack_at is not None: + self._write_ack_frame(builder=builder, space=space, now=now) + + # CRYPTO + if not crypto_stream.send_buffer_is_empty: + if self._write_crypto_frame( + builder=builder, space=space, stream=crypto_stream + ): + self._probe_pending = False + + # PING (probe) + if ( + self._probe_pending + and epoch == tls.Epoch.HANDSHAKE + and not self._handshake_complete + ): + self._write_ping_frame(builder, comment="probe") + self._probe_pending = False + + if builder.packet_is_empty: + break + + def _write_ack_frame( + self, builder: QuicPacketBuilder, space: QuicPacketSpace, now: float + ) -> None: + # calculate ACK delay + ack_delay = now - space.largest_received_time + ack_delay_encoded = int(ack_delay * 1000000) >> self._local_ack_delay_exponent + + buf = builder.start_frame( + QuicFrameType.ACK, + capacity=ACK_FRAME_CAPACITY, + handler=self._on_ack_delivery, + handler_args=(space, space.largest_received_packet), + ) + ranges = push_ack_frame(buf, space.ack_queue, ack_delay_encoded) + space.ack_at = None + + # log frame + if self._quic_logger is not None: + 
builder.quic_logger_frames.append( + self._quic_logger.encode_ack_frame( + ranges=space.ack_queue, delay=ack_delay + ) + ) + + # check if we need to trigger an ACK-of-ACK + if ranges > 1 and builder.packet_number % 8 == 0: + self._write_ping_frame(builder, comment="ACK-of-ACK trigger") + + def _write_connection_close_frame( + self, + builder: QuicPacketBuilder, + error_code: int, + frame_type: Optional[int], + reason_phrase: str, + ) -> None: + reason_bytes = reason_phrase.encode("utf8") + reason_length = len(reason_bytes) + + if frame_type is None: + buf = builder.start_frame( + QuicFrameType.APPLICATION_CLOSE, + capacity=APPLICATION_CLOSE_FRAME_CAPACITY + reason_length, + ) + buf.push_uint_var(error_code) + buf.push_uint_var(reason_length) + buf.push_bytes(reason_bytes) + else: + buf = builder.start_frame( + QuicFrameType.TRANSPORT_CLOSE, + capacity=TRANSPORT_CLOSE_FRAME_CAPACITY + reason_length, + ) + buf.push_uint_var(error_code) + buf.push_uint_var(frame_type) + buf.push_uint_var(reason_length) + buf.push_bytes(reason_bytes) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_connection_close_frame( + error_code=error_code, + frame_type=frame_type, + reason_phrase=reason_phrase, + ) + ) + + def _write_connection_limits( + self, builder: QuicPacketBuilder, space: QuicPacketSpace + ) -> None: + """ + Raise MAX_DATA if needed. 
+ """ + if self._local_max_data_used * 2 > self._local_max_data: + self._local_max_data *= 2 + self._logger.debug("Local max_data raised to %d", self._local_max_data) + if self._local_max_data_sent != self._local_max_data: + buf = builder.start_frame( + QuicFrameType.MAX_DATA, + capacity=MAX_DATA_FRAME_CAPACITY, + handler=self._on_max_data_delivery, + ) + buf.push_uint_var(self._local_max_data) + self._local_max_data_sent = self._local_max_data + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_max_data_frame(self._local_max_data) + ) + + def _write_crypto_frame( + self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream + ) -> bool: + frame_overhead = 3 + size_uint_var(stream.next_send_offset) + frame = stream.get_frame(builder.remaining_flight_space - frame_overhead) + if frame is not None: + buf = builder.start_frame( + QuicFrameType.CRYPTO, + capacity=frame_overhead, + handler=stream.on_data_delivery, + handler_args=(frame.offset, frame.offset + len(frame.data)), + ) + buf.push_uint_var(frame.offset) + buf.push_uint16(len(frame.data) | 0x4000) + buf.push_bytes(frame.data) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_crypto_frame(frame) + ) + return True + + return False + + def _write_datagram_frame( + self, builder: QuicPacketBuilder, data: bytes, frame_type: QuicFrameType + ) -> bool: + """ + Write a DATAGRAM frame. + + Returns True if the frame was processed, False otherwise. 
+ """ + assert frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH + length = len(data) + frame_size = 1 + size_uint_var(length) + length + + buf = builder.start_frame(frame_type, capacity=frame_size) + buf.push_uint_var(length) + buf.push_bytes(data) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_datagram_frame(length=length) + ) + + return True + + def _write_handshake_done_frame(self, builder: QuicPacketBuilder) -> None: + builder.start_frame( + QuicFrameType.HANDSHAKE_DONE, + capacity=HANDSHAKE_DONE_FRAME_CAPACITY, + handler=self._on_handshake_done_delivery, + ) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_handshake_done_frame() + ) + + def _write_new_connection_id_frame( + self, builder: QuicPacketBuilder, connection_id: QuicConnectionId + ) -> None: + retire_prior_to = 0 # FIXME + + buf = builder.start_frame( + QuicFrameType.NEW_CONNECTION_ID, + capacity=NEW_CONNECTION_ID_FRAME_CAPACITY, + handler=self._on_new_connection_id_delivery, + handler_args=(connection_id,), + ) + buf.push_uint_var(connection_id.sequence_number) + buf.push_uint_var(retire_prior_to) + buf.push_uint8(len(connection_id.cid)) + buf.push_bytes(connection_id.cid) + buf.push_bytes(connection_id.stateless_reset_token) + + connection_id.was_sent = True + self._events.append(events.ConnectionIdIssued(connection_id=connection_id.cid)) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_new_connection_id_frame( + connection_id=connection_id.cid, + retire_prior_to=retire_prior_to, + sequence_number=connection_id.sequence_number, + stateless_reset_token=connection_id.stateless_reset_token, + ) + ) + + def _write_path_challenge_frame( + self, builder: QuicPacketBuilder, challenge: bytes + ) -> None: + buf = builder.start_frame( + QuicFrameType.PATH_CHALLENGE, capacity=PATH_CHALLENGE_FRAME_CAPACITY + ) 
+ buf.push_bytes(challenge) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_path_challenge_frame(data=challenge) + ) + + def _write_path_response_frame( + self, builder: QuicPacketBuilder, challenge: bytes + ) -> None: + buf = builder.start_frame( + QuicFrameType.PATH_RESPONSE, capacity=PATH_RESPONSE_FRAME_CAPACITY + ) + buf.push_bytes(challenge) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_path_response_frame(data=challenge) + ) + + def _write_ping_frame( + self, builder: QuicPacketBuilder, uids: List[int] = [], comment="" + ): + builder.start_frame( + QuicFrameType.PING, + capacity=PING_FRAME_CAPACITY, + handler=self._on_ping_delivery, + handler_args=(tuple(uids),), + ) + self._logger.debug( + "Sending PING%s in packet %d", + " (%s)" % comment if comment else "", + builder.packet_number, + ) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append(self._quic_logger.encode_ping_frame()) + + def _write_retire_connection_id_frame( + self, builder: QuicPacketBuilder, sequence_number: int + ) -> None: + buf = builder.start_frame( + QuicFrameType.RETIRE_CONNECTION_ID, + capacity=RETIRE_CONNECTION_ID_CAPACITY, + handler=self._on_retire_connection_id_delivery, + handler_args=(sequence_number,), + ) + buf.push_uint_var(sequence_number) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_retire_connection_id_frame(sequence_number) + ) + + def _write_stream_frame( + self, + builder: QuicPacketBuilder, + space: QuicPacketSpace, + stream: QuicStream, + max_offset: int, + ) -> int: + # the frame data size is constrained by our peer's MAX_DATA and + # the space available in the current packet + frame_overhead = ( + 3 + + size_uint_var(stream.stream_id) + + (size_uint_var(stream.next_send_offset) if stream.next_send_offset else 0) + ) + 
previous_send_highest = stream._send_highest + frame = stream.get_frame( + builder.remaining_flight_space - frame_overhead, max_offset + ) + + if frame is not None: + frame_type = QuicFrameType.STREAM_BASE | 2 # length + if frame.offset: + frame_type |= 4 + if frame.fin: + frame_type |= 1 + buf = builder.start_frame( + frame_type, + capacity=frame_overhead, + handler=stream.on_data_delivery, + handler_args=(frame.offset, frame.offset + len(frame.data)), + ) + buf.push_uint_var(stream.stream_id) + if frame.offset: + buf.push_uint_var(frame.offset) + buf.push_uint16(len(frame.data) | 0x4000) + buf.push_bytes(frame.data) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_stream_frame( + frame, stream_id=stream.stream_id + ) + ) + + return stream._send_highest - previous_send_highest + else: + return 0 + + def _write_stream_limits( + self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream + ) -> None: + """ + Raise MAX_STREAM_DATA if needed. + + The only case where `stream.max_stream_data_local` is zero is for + locally created unidirectional streams. We skip such streams to avoid + spurious logging. 
+ """ + if ( + stream.max_stream_data_local + and stream._recv_highest * 2 > stream.max_stream_data_local + ): + stream.max_stream_data_local *= 2 + self._logger.debug( + "Stream %d local max_stream_data raised to %d", + stream.stream_id, + stream.max_stream_data_local, + ) + if stream.max_stream_data_local_sent != stream.max_stream_data_local: + buf = builder.start_frame( + QuicFrameType.MAX_STREAM_DATA, + capacity=MAX_STREAM_DATA_FRAME_CAPACITY, + handler=self._on_max_stream_data_delivery, + handler_args=(stream,), + ) + buf.push_uint_var(stream.stream_id) + buf.push_uint_var(stream.max_stream_data_local) + stream.max_stream_data_local_sent = stream.max_stream_data_local + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_max_stream_data_frame( + maximum=stream.max_stream_data_local, stream_id=stream.stream_id + ) + ) + + def _write_streams_blocked_frame( + self, builder: QuicPacketBuilder, frame_type: QuicFrameType, limit: int + ) -> None: + buf = builder.start_frame(frame_type, capacity=STREAMS_BLOCKED_CAPACITY) + buf.push_uint_var(limit) + + # log frame + if self._quic_logger is not None: + builder.quic_logger_frames.append( + self._quic_logger.encode_streams_blocked_frame( + is_unidirectional=frame_type == QuicFrameType.STREAMS_BLOCKED_UNI, + limit=limit, + ) + ) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/crypto.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/crypto.py new file mode 100644 index 000000000000..8acb30eeda8f --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/crypto.py @@ -0,0 +1,196 @@ +import binascii +from typing import Optional, Tuple + +from .._crypto import AEAD, CryptoError, HeaderProtection +from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract +from .packet import decode_packet_number, is_long_header + +CIPHER_SUITES = { + 
CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"), + CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"), + CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"), +} +INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256 +INITIAL_SALT = binascii.unhexlify("c3eef712c72ebb5a11a7d2432bb46365bef9f502") +SAMPLE_SIZE = 16 + + +class KeyUnavailableError(CryptoError): + pass + + +def derive_key_iv_hp( + cipher_suite: CipherSuite, secret: bytes +) -> Tuple[bytes, bytes, bytes]: + algorithm = cipher_suite_hash(cipher_suite) + if cipher_suite in [ + CipherSuite.AES_256_GCM_SHA384, + CipherSuite.CHACHA20_POLY1305_SHA256, + ]: + key_size = 32 + else: + key_size = 16 + return ( + hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size), + hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12), + hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size), + ) + + +class CryptoContext: + def __init__(self, key_phase: int = 0) -> None: + self.aead: Optional[AEAD] = None + self.cipher_suite: Optional[CipherSuite] = None + self.hp: Optional[HeaderProtection] = None + self.key_phase = key_phase + self.secret: Optional[bytes] = None + self.version: Optional[int] = None + + def decrypt_packet( + self, packet: bytes, encrypted_offset: int, expected_packet_number: int + ) -> Tuple[bytes, bytes, int, bool]: + if self.aead is None: + raise KeyUnavailableError("Decryption key is not available") + + # header protection + plain_header, packet_number = self.hp.remove(packet, encrypted_offset) + first_byte = plain_header[0] + + # packet number + pn_length = (first_byte & 0x03) + 1 + packet_number = decode_packet_number( + packet_number, pn_length * 8, expected_packet_number + ) + + # detect key phase change + crypto = self + if not is_long_header(first_byte): + key_phase = (first_byte & 4) >> 2 + if key_phase != self.key_phase: + crypto = next_key_phase(self) + + # payload protection + payload = crypto.aead.decrypt( + 
packet[len(plain_header) :], plain_header, packet_number + ) + + return plain_header, payload, packet_number, crypto != self + + def encrypt_packet( + self, plain_header: bytes, plain_payload: bytes, packet_number: int + ) -> bytes: + assert self.is_valid(), "Encryption key is not available" + + # payload protection + protected_payload = self.aead.encrypt( + plain_payload, plain_header, packet_number + ) + + # header protection + return self.hp.apply(plain_header, protected_payload) + + def is_valid(self) -> bool: + return self.aead is not None + + def setup(self, cipher_suite: CipherSuite, secret: bytes, version: int) -> None: + hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite] + + key, iv, hp = derive_key_iv_hp(cipher_suite, secret) + self.aead = AEAD(aead_cipher_name, key, iv) + self.cipher_suite = cipher_suite + self.hp = HeaderProtection(hp_cipher_name, hp) + self.secret = secret + self.version = version + + def teardown(self) -> None: + self.aead = None + self.cipher_suite = None + self.hp = None + self.secret = None + + +def apply_key_phase(self: CryptoContext, crypto: CryptoContext) -> None: + self.aead = crypto.aead + self.key_phase = crypto.key_phase + self.secret = crypto.secret + + +def next_key_phase(self: CryptoContext) -> CryptoContext: + algorithm = cipher_suite_hash(self.cipher_suite) + + crypto = CryptoContext(key_phase=int(not self.key_phase)) + crypto.setup( + cipher_suite=self.cipher_suite, + secret=hkdf_expand_label( + algorithm, self.secret, b"quic ku", b"", algorithm.digest_size + ), + version=self.version, + ) + return crypto + + +class CryptoPair: + def __init__(self) -> None: + self.aead_tag_size = 16 + self.recv = CryptoContext() + self.send = CryptoContext() + self._update_key_requested = False + + def decrypt_packet( + self, packet: bytes, encrypted_offset: int, expected_packet_number: int + ) -> Tuple[bytes, bytes, int]: + plain_header, payload, packet_number, update_key = self.recv.decrypt_packet( + packet, 
encrypted_offset, expected_packet_number + ) + if update_key: + self._update_key() + return plain_header, payload, packet_number + + def encrypt_packet( + self, plain_header: bytes, plain_payload: bytes, packet_number: int + ) -> bytes: + if self._update_key_requested: + self._update_key() + return self.send.encrypt_packet(plain_header, plain_payload, packet_number) + + def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None: + if is_client: + recv_label, send_label = b"server in", b"client in" + else: + recv_label, send_label = b"client in", b"server in" + + algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE) + initial_secret = hkdf_extract(algorithm, INITIAL_SALT, cid) + self.recv.setup( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=hkdf_expand_label( + algorithm, initial_secret, recv_label, b"", algorithm.digest_size + ), + version=version, + ) + self.send.setup( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=hkdf_expand_label( + algorithm, initial_secret, send_label, b"", algorithm.digest_size + ), + version=version, + ) + + def teardown(self) -> None: + self.recv.teardown() + self.send.teardown() + + def update_key(self) -> None: + self._update_key_requested = True + + @property + def key_phase(self) -> int: + if self._update_key_requested: + return int(not self.recv.key_phase) + else: + return self.recv.key_phase + + def _update_key(self) -> None: + apply_key_phase(self.recv, next_key_phase(self.recv)) + apply_key_phase(self.send, next_key_phase(self.send)) + self._update_key_requested = False diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/events.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/events.py new file mode 100644 index 000000000000..4430693936c1 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/events.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass +from typing import Optional + + +class QuicEvent: + """ + Base class for 
QUIC events. + """ + + pass + + +@dataclass +class ConnectionIdIssued(QuicEvent): + connection_id: bytes + + +@dataclass +class ConnectionIdRetired(QuicEvent): + connection_id: bytes + + +@dataclass +class ConnectionTerminated(QuicEvent): + """ + The ConnectionTerminated event is fired when the QUIC connection is terminated. + """ + + error_code: int + "The error code which was specified when closing the connection." + + frame_type: Optional[int] + "The frame type which caused the connection to be closed, or `None`." + + reason_phrase: str + "The human-readable reason for which the connection was closed." + + +@dataclass +class DatagramFrameReceived(QuicEvent): + """ + The DatagramFrameReceived event is fired when a DATAGRAM frame is received. + """ + + data: bytes + "The data which was received." + + +@dataclass +class HandshakeCompleted(QuicEvent): + """ + The HandshakeCompleted event is fired when the TLS handshake completes. + """ + + alpn_protocol: Optional[str] + "The protocol which was negotiated using ALPN, or `None`." + + early_data_accepted: bool + "Whether early (0-RTT) data was accepted by the remote peer." + + session_resumed: bool + "Whether a TLS session was resumed." + + +@dataclass +class PingAcknowledged(QuicEvent): + """ + The PingAcknowledged event is fired when a PING frame is acknowledged. + """ + + uid: int + "The unique ID of the PING." + + +@dataclass +class ProtocolNegotiated(QuicEvent): + """ + The ProtocolNegotiated event is fired when ALPN negotiation completes. + """ + + alpn_protocol: Optional[str] + "The protocol which was negotiated using ALPN, or `None`." + + +@dataclass +class StreamDataReceived(QuicEvent): + """ + The StreamDataReceived event is fired whenever data is received on a + stream. + """ + + data: bytes + "The data which was received." + + end_stream: bool + "Whether the STREAM frame had the FIN bit set." + + stream_id: int + "The ID of the stream the data was received for." 
+ + +@dataclass +class StreamReset(QuicEvent): + """ + The StreamReset event is fired when the remote peer resets a stream. + """ + + error_code: int + "The error code that triggered the reset." + + stream_id: int + "The ID of the stream that was reset." diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/logger.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/logger.py new file mode 100644 index 000000000000..c4975bd592e1 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/logger.py @@ -0,0 +1,266 @@ +import binascii +import time +from collections import deque +from typing import Any, Deque, Dict, List, Optional, Tuple + +from .packet import ( + PACKET_TYPE_HANDSHAKE, + PACKET_TYPE_INITIAL, + PACKET_TYPE_MASK, + PACKET_TYPE_ONE_RTT, + PACKET_TYPE_RETRY, + PACKET_TYPE_ZERO_RTT, + QuicStreamFrame, + QuicTransportParameters, +) +from .rangeset import RangeSet + +PACKET_TYPE_NAMES = { + PACKET_TYPE_INITIAL: "initial", + PACKET_TYPE_HANDSHAKE: "handshake", + PACKET_TYPE_ZERO_RTT: "0RTT", + PACKET_TYPE_ONE_RTT: "1RTT", + PACKET_TYPE_RETRY: "retry", +} + + +def hexdump(data: bytes) -> str: + return binascii.hexlify(data).decode("ascii") + + +class QuicLoggerTrace: + """ + A QUIC event trace. + + Events are logged in the format defined by qlog draft-01. 
+ + See: https://quiclog.github.io/internet-drafts/draft-marx-qlog-event-definitions-quic-h3.html + """ + + def __init__(self, *, is_client: bool, odcid: bytes) -> None: + self._odcid = odcid + self._events: Deque[Tuple[float, str, str, Dict[str, Any]]] = deque() + self._vantage_point = { + "name": "aioquic", + "type": "client" if is_client else "server", + } + + def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict: + return { + "ack_delay": str(self.encode_time(delay)), + "acked_ranges": [[str(x.start), str(x.stop - 1)] for x in ranges], + "frame_type": "ack", + } + + def encode_connection_close_frame( + self, error_code: int, frame_type: Optional[int], reason_phrase: str + ) -> Dict: + attrs = { + "error_code": error_code, + "error_space": "application" if frame_type is None else "transport", + "frame_type": "connection_close", + "raw_error_code": error_code, + "reason": reason_phrase, + } + if frame_type is not None: + attrs["trigger_frame_type"] = frame_type + + return attrs + + def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict: + return { + "frame_type": "crypto", + "length": len(frame.data), + "offset": str(frame.offset), + } + + def encode_data_blocked_frame(self, limit: int) -> Dict: + return {"frame_type": "data_blocked", "limit": str(limit)} + + def encode_datagram_frame(self, length: int) -> Dict: + return {"frame_type": "datagram", "length": length} + + def encode_handshake_done_frame(self) -> Dict: + return {"frame_type": "handshake_done"} + + def encode_max_data_frame(self, maximum: int) -> Dict: + return {"frame_type": "max_data", "maximum": str(maximum)} + + def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> Dict: + return { + "frame_type": "max_stream_data", + "maximum": str(maximum), + "stream_id": str(stream_id), + } + + def encode_max_streams_frame(self, is_unidirectional: bool, maximum: int) -> Dict: + return { + "frame_type": "max_streams", + "maximum": str(maximum), + "stream_type": 
"unidirectional" if is_unidirectional else "bidirectional", + } + + def encode_new_connection_id_frame( + self, + connection_id: bytes, + retire_prior_to: int, + sequence_number: int, + stateless_reset_token: bytes, + ) -> Dict: + return { + "connection_id": hexdump(connection_id), + "frame_type": "new_connection_id", + "length": len(connection_id), + "reset_token": hexdump(stateless_reset_token), + "retire_prior_to": str(retire_prior_to), + "sequence_number": str(sequence_number), + } + + def encode_new_token_frame(self, token: bytes) -> Dict: + return { + "frame_type": "new_token", + "length": len(token), + "token": hexdump(token), + } + + def encode_padding_frame(self) -> Dict: + return {"frame_type": "padding"} + + def encode_path_challenge_frame(self, data: bytes) -> Dict: + return {"data": hexdump(data), "frame_type": "path_challenge"} + + def encode_path_response_frame(self, data: bytes) -> Dict: + return {"data": hexdump(data), "frame_type": "path_response"} + + def encode_ping_frame(self) -> Dict: + return {"frame_type": "ping"} + + def encode_reset_stream_frame( + self, error_code: int, final_size: int, stream_id: int + ) -> Dict: + return { + "error_code": error_code, + "final_size": str(final_size), + "frame_type": "reset_stream", + "stream_id": str(stream_id), + } + + def encode_retire_connection_id_frame(self, sequence_number: int) -> Dict: + return { + "frame_type": "retire_connection_id", + "sequence_number": str(sequence_number), + } + + def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> Dict: + return { + "frame_type": "stream_data_blocked", + "limit": str(limit), + "stream_id": str(stream_id), + } + + def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> Dict: + return { + "frame_type": "stop_sending", + "error_code": error_code, + "stream_id": str(stream_id), + } + + def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict: + return { + "fin": frame.fin, + "frame_type": "stream", + 
"length": len(frame.data), + "offset": str(frame.offset), + "stream_id": str(stream_id), + } + + def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> Dict: + return { + "frame_type": "streams_blocked", + "limit": str(limit), + "stream_type": "unidirectional" if is_unidirectional else "bidirectional", + } + + def encode_time(self, seconds: float) -> int: + """ + Convert a time to integer microseconds. + """ + return int(seconds * 1000000) + + def encode_transport_parameters( + self, owner: str, parameters: QuicTransportParameters + ) -> Dict[str, Any]: + data: Dict[str, Any] = {"owner": owner} + for param_name, param_value in parameters.__dict__.items(): + if isinstance(param_value, bool): + data[param_name] = param_value + elif isinstance(param_value, bytes): + data[param_name] = hexdump(param_value) + elif isinstance(param_value, int): + data[param_name] = param_value + return data + + def log_event(self, *, category: str, event: str, data: Dict) -> None: + self._events.append((time.time(), category, event, data)) + + def packet_type(self, packet_type: int) -> str: + return PACKET_TYPE_NAMES.get(packet_type & PACKET_TYPE_MASK, "1RTT") + + def to_dict(self) -> Dict[str, Any]: + """ + Return the trace as a dictionary which can be written as JSON. + """ + if self._events: + reference_time = self._events[0][0] + else: + reference_time = 0.0 + return { + "configuration": {"time_units": "us"}, + "common_fields": { + "ODCID": hexdump(self._odcid), + "reference_time": str(self.encode_time(reference_time)), + }, + "event_fields": ["relative_time", "category", "event_type", "data"], + "events": list( + map( + lambda event: ( + str(self.encode_time(event[0] - reference_time)), + event[1], + event[2], + event[3], + ), + self._events, + ) + ), + "vantage_point": self._vantage_point, + } + + +class QuicLogger: + """ + A QUIC event logger. + + Serves as a container for traces in the format defined by qlog draft-01. 
+ + See: https://quiclog.github.io/internet-drafts/draft-marx-qlog-main-schema.html + """ + + def __init__(self) -> None: + self._traces: List[QuicLoggerTrace] = [] + + def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace: + trace = QuicLoggerTrace(is_client=is_client, odcid=odcid) + self._traces.append(trace) + return trace + + def end_trace(self, trace: QuicLoggerTrace) -> None: + assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger" + + def to_dict(self) -> Dict[str, Any]: + """ + Return the traces as a dictionary which can be written as JSON. + """ + return { + "qlog_version": "draft-01", + "traces": [trace.to_dict() for trace in self._traces], + } diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/packet.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/packet.py new file mode 100644 index 000000000000..b9944599085d --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/packet.py @@ -0,0 +1,508 @@ +import binascii +import ipaddress +import os +from dataclasses import dataclass +from enum import IntEnum +from typing import List, Optional, Tuple + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from ..buffer import Buffer +from ..tls import pull_block, push_block +from .rangeset import RangeSet + +PACKET_LONG_HEADER = 0x80 +PACKET_FIXED_BIT = 0x40 +PACKET_SPIN_BIT = 0x20 + +PACKET_TYPE_INITIAL = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x00 +PACKET_TYPE_ZERO_RTT = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x10 +PACKET_TYPE_HANDSHAKE = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x20 +PACKET_TYPE_RETRY = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x30 +PACKET_TYPE_ONE_RTT = PACKET_FIXED_BIT +PACKET_TYPE_MASK = 0xF0 + +CONNECTION_ID_MAX_SIZE = 20 +PACKET_NUMBER_MAX_SIZE = 4 +RETRY_AEAD_KEY = binascii.unhexlify("4d32ecdb2a2133c841e4043df27d4430") +RETRY_AEAD_NONCE = binascii.unhexlify("4d1611d05513a552c587d575") 
+RETRY_INTEGRITY_TAG_SIZE = 16 + + +class QuicErrorCode(IntEnum): + NO_ERROR = 0x0 + INTERNAL_ERROR = 0x1 + SERVER_BUSY = 0x2 + FLOW_CONTROL_ERROR = 0x3 + STREAM_LIMIT_ERROR = 0x4 + STREAM_STATE_ERROR = 0x5 + FINAL_SIZE_ERROR = 0x6 + FRAME_ENCODING_ERROR = 0x7 + TRANSPORT_PARAMETER_ERROR = 0x8 + CONNECTION_ID_LIMIT_ERROR = 0x9 + PROTOCOL_VIOLATION = 0xA + INVALID_TOKEN = 0xB + CRYPTO_BUFFER_EXCEEDED = 0xD + CRYPTO_ERROR = 0x100 + + +class QuicProtocolVersion(IntEnum): + NEGOTIATION = 0 + DRAFT_25 = 0xFF000019 + DRAFT_26 = 0xFF00001A + DRAFT_27 = 0xFF00001B + + +@dataclass +class QuicHeader: + is_long_header: bool + version: Optional[int] + packet_type: int + destination_cid: bytes + source_cid: bytes + token: bytes = b"" + integrity_tag: bytes = b"" + rest_length: int = 0 + + +def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int: + """ + Recover a packet number from a truncated packet number. + + See: Appendix A - Sample Packet Number Decoding Algorithm + """ + window = 1 << num_bits + half_window = window // 2 + candidate = (expected & ~(window - 1)) | truncated + if candidate <= expected - half_window and candidate < (1 << 62) - window: + return candidate + window + elif candidate > expected + half_window and candidate >= window: + return candidate - window + else: + return candidate + + +def get_retry_integrity_tag( + packet_without_tag: bytes, original_destination_cid: bytes +) -> bytes: + """ + Calculate the integrity tag for a RETRY packet. 
+ """ + # build Retry pseudo packet + buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag)) + buf.push_uint8(len(original_destination_cid)) + buf.push_bytes(original_destination_cid) + buf.push_bytes(packet_without_tag) + assert buf.eof() + + # run AES-128-GCM + aead = AESGCM(RETRY_AEAD_KEY) + integrity_tag = aead.encrypt(RETRY_AEAD_NONCE, b"", buf.data) + assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE + return integrity_tag + + +def get_spin_bit(first_byte: int) -> bool: + return bool(first_byte & PACKET_SPIN_BIT) + + +def is_long_header(first_byte: int) -> bool: + return bool(first_byte & PACKET_LONG_HEADER) + + +def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader: + first_byte = buf.pull_uint8() + + integrity_tag = b"" + token = b"" + if is_long_header(first_byte): + # long header packet + version = buf.pull_uint32() + + destination_cid_length = buf.pull_uint8() + if destination_cid_length > CONNECTION_ID_MAX_SIZE: + raise ValueError( + "Destination CID is too long (%d bytes)" % destination_cid_length + ) + destination_cid = buf.pull_bytes(destination_cid_length) + + source_cid_length = buf.pull_uint8() + if source_cid_length > CONNECTION_ID_MAX_SIZE: + raise ValueError("Source CID is too long (%d bytes)" % source_cid_length) + source_cid = buf.pull_bytes(source_cid_length) + + if version == QuicProtocolVersion.NEGOTIATION: + # version negotiation + packet_type = None + rest_length = buf.capacity - buf.tell() + else: + if not (first_byte & PACKET_FIXED_BIT): + raise ValueError("Packet fixed bit is zero") + + packet_type = first_byte & PACKET_TYPE_MASK + if packet_type == PACKET_TYPE_INITIAL: + token_length = buf.pull_uint_var() + token = buf.pull_bytes(token_length) + rest_length = buf.pull_uint_var() + elif packet_type == PACKET_TYPE_RETRY: + token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE + token = buf.pull_bytes(token_length) + integrity_tag = 
buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE) + rest_length = 0 + else: + rest_length = buf.pull_uint_var() + + return QuicHeader( + is_long_header=True, + version=version, + packet_type=packet_type, + destination_cid=destination_cid, + source_cid=source_cid, + token=token, + integrity_tag=integrity_tag, + rest_length=rest_length, + ) + else: + # short header packet + if not (first_byte & PACKET_FIXED_BIT): + raise ValueError("Packet fixed bit is zero") + + packet_type = first_byte & PACKET_TYPE_MASK + destination_cid = buf.pull_bytes(host_cid_length) + return QuicHeader( + is_long_header=False, + version=None, + packet_type=packet_type, + destination_cid=destination_cid, + source_cid=b"", + token=b"", + rest_length=buf.capacity - buf.tell(), + ) + + +def encode_quic_retry( + version: int, + source_cid: bytes, + destination_cid: bytes, + original_destination_cid: bytes, + retry_token: bytes, +) -> bytes: + buf = Buffer( + capacity=7 + + len(destination_cid) + + len(source_cid) + + len(retry_token) + + RETRY_INTEGRITY_TAG_SIZE + ) + buf.push_uint8(PACKET_TYPE_RETRY) + buf.push_uint32(version) + buf.push_uint8(len(destination_cid)) + buf.push_bytes(destination_cid) + buf.push_uint8(len(source_cid)) + buf.push_bytes(source_cid) + buf.push_bytes(retry_token) + buf.push_bytes(get_retry_integrity_tag(buf.data, original_destination_cid)) + assert buf.eof() + return buf.data + + +def encode_quic_version_negotiation( + source_cid: bytes, destination_cid: bytes, supported_versions: List[int] +) -> bytes: + buf = Buffer( + capacity=7 + + len(destination_cid) + + len(source_cid) + + 4 * len(supported_versions) + ) + buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER) + buf.push_uint32(QuicProtocolVersion.NEGOTIATION) + buf.push_uint8(len(destination_cid)) + buf.push_bytes(destination_cid) + buf.push_uint8(len(source_cid)) + buf.push_bytes(source_cid) + for version in supported_versions: + buf.push_uint32(version) + return buf.data + + +# TLS EXTENSION + + +@dataclass +class 
QuicPreferredAddress: + ipv4_address: Optional[Tuple[str, int]] + ipv6_address: Optional[Tuple[str, int]] + connection_id: bytes + stateless_reset_token: bytes + + +@dataclass +class QuicTransportParameters: + original_connection_id: Optional[bytes] = None + idle_timeout: Optional[int] = None + stateless_reset_token: Optional[bytes] = None + max_packet_size: Optional[int] = None + initial_max_data: Optional[int] = None + initial_max_stream_data_bidi_local: Optional[int] = None + initial_max_stream_data_bidi_remote: Optional[int] = None + initial_max_stream_data_uni: Optional[int] = None + initial_max_streams_bidi: Optional[int] = None + initial_max_streams_uni: Optional[int] = None + ack_delay_exponent: Optional[int] = None + max_ack_delay: Optional[int] = None + disable_active_migration: Optional[bool] = False + preferred_address: Optional[QuicPreferredAddress] = None + active_connection_id_limit: Optional[int] = None + max_datagram_frame_size: Optional[int] = None + quantum_readiness: Optional[bytes] = None + + +PARAMS = { + 0: ("original_connection_id", bytes), + 1: ("idle_timeout", int), + 2: ("stateless_reset_token", bytes), + 3: ("max_packet_size", int), + 4: ("initial_max_data", int), + 5: ("initial_max_stream_data_bidi_local", int), + 6: ("initial_max_stream_data_bidi_remote", int), + 7: ("initial_max_stream_data_uni", int), + 8: ("initial_max_streams_bidi", int), + 9: ("initial_max_streams_uni", int), + 10: ("ack_delay_exponent", int), + 11: ("max_ack_delay", int), + 12: ("disable_active_migration", bool), + 13: ("preferred_address", QuicPreferredAddress), + 14: ("active_connection_id_limit", int), + 32: ("max_datagram_frame_size", int), + 3127: ("quantum_readiness", bytes), +} + + +def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress: + ipv4_address = None + ipv4_host = buf.pull_bytes(4) + ipv4_port = buf.pull_uint16() + if ipv4_host != bytes(4): + ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port) + + ipv6_address = None + 
ipv6_host = buf.pull_bytes(16) + ipv6_port = buf.pull_uint16() + if ipv6_host != bytes(16): + ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port) + + connection_id_length = buf.pull_uint8() + connection_id = buf.pull_bytes(connection_id_length) + stateless_reset_token = buf.pull_bytes(16) + + return QuicPreferredAddress( + ipv4_address=ipv4_address, + ipv6_address=ipv6_address, + connection_id=connection_id, + stateless_reset_token=stateless_reset_token, + ) + + +def push_quic_preferred_address( + buf: Buffer, preferred_address: QuicPreferredAddress +) -> None: + if preferred_address.ipv4_address is not None: + buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed) + buf.push_uint16(preferred_address.ipv4_address[1]) + else: + buf.push_bytes(bytes(6)) + + if preferred_address.ipv6_address is not None: + buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed) + buf.push_uint16(preferred_address.ipv6_address[1]) + else: + buf.push_bytes(bytes(18)) + + buf.push_uint8(len(preferred_address.connection_id)) + buf.push_bytes(preferred_address.connection_id) + buf.push_bytes(preferred_address.stateless_reset_token) + + +def pull_quic_transport_parameters( + buf: Buffer, protocol_version: int +) -> QuicTransportParameters: + params = QuicTransportParameters() + + if protocol_version < QuicProtocolVersion.DRAFT_27: + with pull_block(buf, 2) as length: + end = buf.tell() + length + while buf.tell() < end: + param_id = buf.pull_uint16() + param_len = buf.pull_uint16() + param_start = buf.tell() + if param_id in PARAMS: + # parse known parameter + param_name, param_type = PARAMS[param_id] + if param_type == int: + setattr(params, param_name, buf.pull_uint_var()) + elif param_type == bytes: + setattr(params, param_name, buf.pull_bytes(param_len)) + elif param_type == QuicPreferredAddress: + setattr(params, param_name, pull_quic_preferred_address(buf)) + else: + setattr(params, param_name, True) + else: + # skip 
unknown parameter + buf.pull_bytes(param_len) + assert buf.tell() == param_start + param_len + else: + while not buf.eof(): + param_id = buf.pull_uint_var() + param_len = buf.pull_uint_var() + param_start = buf.tell() + if param_id in PARAMS: + # parse known parameter + param_name, param_type = PARAMS[param_id] + if param_type == int: + setattr(params, param_name, buf.pull_uint_var()) + elif param_type == bytes: + setattr(params, param_name, buf.pull_bytes(param_len)) + elif param_type == QuicPreferredAddress: + setattr(params, param_name, pull_quic_preferred_address(buf)) + else: + setattr(params, param_name, True) + else: + # skip unknown parameter + buf.pull_bytes(param_len) + assert buf.tell() == param_start + param_len + + return params + + +def push_quic_transport_parameters( + buf: Buffer, params: QuicTransportParameters, protocol_version: int +) -> None: + if protocol_version < QuicProtocolVersion.DRAFT_27: + with push_block(buf, 2): + for param_id, (param_name, param_type) in PARAMS.items(): + param_value = getattr(params, param_name) + if param_value is not None and param_value is not False: + buf.push_uint16(param_id) + with push_block(buf, 2): + if param_type == int: + buf.push_uint_var(param_value) + elif param_type == bytes: + buf.push_bytes(param_value) + elif param_type == QuicPreferredAddress: + push_quic_preferred_address(buf, param_value) + else: + for param_id, (param_name, param_type) in PARAMS.items(): + param_value = getattr(params, param_name) + if param_value is not None and param_value is not False: + param_buf = Buffer(capacity=65536) + if param_type == int: + param_buf.push_uint_var(param_value) + elif param_type == bytes: + param_buf.push_bytes(param_value) + elif param_type == QuicPreferredAddress: + push_quic_preferred_address(param_buf, param_value) + buf.push_uint_var(param_id) + buf.push_uint_var(param_buf.tell()) + buf.push_bytes(param_buf.data) + + +# FRAMES + + +class QuicFrameType(IntEnum): + PADDING = 0x00 + PING = 0x01 + ACK 
= 0x02 + ACK_ECN = 0x03 + RESET_STREAM = 0x04 + STOP_SENDING = 0x05 + CRYPTO = 0x06 + NEW_TOKEN = 0x07 + STREAM_BASE = 0x08 + MAX_DATA = 0x10 + MAX_STREAM_DATA = 0x11 + MAX_STREAMS_BIDI = 0x12 + MAX_STREAMS_UNI = 0x13 + DATA_BLOCKED = 0x14 + STREAM_DATA_BLOCKED = 0x15 + STREAMS_BLOCKED_BIDI = 0x16 + STREAMS_BLOCKED_UNI = 0x17 + NEW_CONNECTION_ID = 0x18 + RETIRE_CONNECTION_ID = 0x19 + PATH_CHALLENGE = 0x1A + PATH_RESPONSE = 0x1B + TRANSPORT_CLOSE = 0x1C + APPLICATION_CLOSE = 0x1D + HANDSHAKE_DONE = 0x1E + DATAGRAM = 0x30 + DATAGRAM_WITH_LENGTH = 0x31 + + +NON_ACK_ELICITING_FRAME_TYPES = frozenset( + [ + QuicFrameType.ACK, + QuicFrameType.ACK_ECN, + QuicFrameType.PADDING, + QuicFrameType.TRANSPORT_CLOSE, + QuicFrameType.APPLICATION_CLOSE, + ] +) +NON_IN_FLIGHT_FRAME_TYPES = frozenset( + [ + QuicFrameType.ACK, + QuicFrameType.ACK_ECN, + QuicFrameType.TRANSPORT_CLOSE, + QuicFrameType.APPLICATION_CLOSE, + ] +) + +PROBING_FRAME_TYPES = frozenset( + [ + QuicFrameType.PATH_CHALLENGE, + QuicFrameType.PATH_RESPONSE, + QuicFrameType.PADDING, + QuicFrameType.NEW_CONNECTION_ID, + ] +) + + +@dataclass +class QuicStreamFrame: + data: bytes = b"" + fin: bool = False + offset: int = 0 + + +def pull_ack_frame(buf: Buffer) -> Tuple[RangeSet, int]: + rangeset = RangeSet() + end = buf.pull_uint_var() # largest acknowledged + delay = buf.pull_uint_var() + ack_range_count = buf.pull_uint_var() + ack_count = buf.pull_uint_var() # first ack range + rangeset.add(end - ack_count, end + 1) + end -= ack_count + for _ in range(ack_range_count): + end -= buf.pull_uint_var() + 2 + ack_count = buf.pull_uint_var() + rangeset.add(end - ack_count, end + 1) + end -= ack_count + return rangeset, delay + + +def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int: + ranges = len(rangeset) + index = ranges - 1 + r = rangeset[index] + buf.push_uint_var(r.stop - 1) + buf.push_uint_var(delay) + buf.push_uint_var(index) + buf.push_uint_var(r.stop - 1 - r.start) + start = r.start + while index > 
0: + index -= 1 + r = rangeset[index] + buf.push_uint_var(start - r.stop - 1) + buf.push_uint_var(r.stop - r.start - 1) + start = r.start + return ranges diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/packet_builder.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/packet_builder.py new file mode 100644 index 000000000000..08837c2c3304 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/packet_builder.py @@ -0,0 +1,366 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +from ..buffer import Buffer, size_uint_var +from ..tls import Epoch +from .crypto import CryptoPair +from .logger import QuicLoggerTrace +from .packet import ( + NON_ACK_ELICITING_FRAME_TYPES, + NON_IN_FLIGHT_FRAME_TYPES, + PACKET_NUMBER_MAX_SIZE, + PACKET_TYPE_HANDSHAKE, + PACKET_TYPE_INITIAL, + PACKET_TYPE_MASK, + QuicFrameType, + is_long_header, +) + +PACKET_MAX_SIZE = 1280 +PACKET_LENGTH_SEND_SIZE = 2 +PACKET_NUMBER_SEND_SIZE = 2 + + +QuicDeliveryHandler = Callable[..., None] + + +class QuicDeliveryState(Enum): + ACKED = 0 + LOST = 1 + EXPIRED = 2 + + +@dataclass +class QuicSentPacket: + epoch: Epoch + in_flight: bool + is_ack_eliciting: bool + is_crypto_packet: bool + packet_number: int + packet_type: int + sent_time: Optional[float] = None + sent_bytes: int = 0 + + delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field( + default_factory=list + ) + quic_logger_frames: List[Dict] = field(default_factory=list) + + +class QuicPacketBuilderStop(Exception): + pass + + +class QuicPacketBuilder: + """ + Helper for building QUIC packets. 
+ """ + + def __init__( + self, + *, + host_cid: bytes, + peer_cid: bytes, + version: int, + is_client: bool, + packet_number: int = 0, + peer_token: bytes = b"", + quic_logger: Optional[QuicLoggerTrace] = None, + spin_bit: bool = False, + ): + self.max_flight_bytes: Optional[int] = None + self.max_total_bytes: Optional[int] = None + self.quic_logger_frames: Optional[List[Dict]] = None + + self._host_cid = host_cid + self._is_client = is_client + self._peer_cid = peer_cid + self._peer_token = peer_token + self._quic_logger = quic_logger + self._spin_bit = spin_bit + self._version = version + + # assembled datagrams and packets + self._datagrams: List[bytes] = [] + self._datagram_flight_bytes = 0 + self._datagram_init = True + self._packets: List[QuicSentPacket] = [] + self._flight_bytes = 0 + self._total_bytes = 0 + + # current packet + self._header_size = 0 + self._packet: Optional[QuicSentPacket] = None + self._packet_crypto: Optional[CryptoPair] = None + self._packet_long_header = False + self._packet_number = packet_number + self._packet_start = 0 + self._packet_type = 0 + + self._buffer = Buffer(PACKET_MAX_SIZE) + self._buffer_capacity = PACKET_MAX_SIZE + self._flight_capacity = PACKET_MAX_SIZE + + @property + def packet_is_empty(self) -> bool: + """ + Returns `True` if the current packet is empty. + """ + assert self._packet is not None + packet_size = self._buffer.tell() - self._packet_start + return packet_size <= self._header_size + + @property + def packet_number(self) -> int: + """ + Returns the packet number for the next packet. + """ + return self._packet_number + + @property + def remaining_buffer_space(self) -> int: + """ + Returns the remaining number of bytes which can be used in + the current packet. 
+ """ + return ( + self._buffer_capacity + - self._buffer.tell() + - self._packet_crypto.aead_tag_size + ) + + @property + def remaining_flight_space(self) -> int: + """ + Returns the remaining number of bytes which can be used in + the current packet. + """ + return ( + self._flight_capacity + - self._buffer.tell() + - self._packet_crypto.aead_tag_size + ) + + def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]: + """ + Returns the assembled datagrams. + """ + if self._packet is not None: + self._end_packet() + self._flush_current_datagram() + + datagrams = self._datagrams + packets = self._packets + self._datagrams = [] + self._packets = [] + return datagrams, packets + + def start_frame( + self, + frame_type: int, + capacity: int = 1, + handler: Optional[QuicDeliveryHandler] = None, + handler_args: Sequence[Any] = [], + ) -> Buffer: + """ + Starts a new frame. + """ + if self.remaining_buffer_space < capacity or ( + frame_type not in NON_IN_FLIGHT_FRAME_TYPES + and self.remaining_flight_space < capacity + ): + raise QuicPacketBuilderStop + + self._buffer.push_uint_var(frame_type) + if frame_type not in NON_ACK_ELICITING_FRAME_TYPES: + self._packet.is_ack_eliciting = True + if frame_type not in NON_IN_FLIGHT_FRAME_TYPES: + self._packet.in_flight = True + if frame_type == QuicFrameType.CRYPTO: + self._packet.is_crypto_packet = True + if handler is not None: + self._packet.delivery_handlers.append((handler, handler_args)) + return self._buffer + + def start_packet(self, packet_type: int, crypto: CryptoPair) -> None: + """ + Starts a new packet. + """ + buf = self._buffer + + # finish previous datagram + if self._packet is not None: + self._end_packet() + + # if there is too little space remaining, start a new datagram + # FIXME: the limit is arbitrary! 
+ packet_start = buf.tell() + if self._buffer_capacity - packet_start < 128: + self._flush_current_datagram() + packet_start = 0 + + # initialize datagram if needed + if self._datagram_init: + if self.max_total_bytes is not None: + remaining_total_bytes = self.max_total_bytes - self._total_bytes + if remaining_total_bytes < self._buffer_capacity: + self._buffer_capacity = remaining_total_bytes + + self._flight_capacity = self._buffer_capacity + if self.max_flight_bytes is not None: + remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes + if remaining_flight_bytes < self._flight_capacity: + self._flight_capacity = remaining_flight_bytes + self._datagram_flight_bytes = 0 + self._datagram_init = False + + # calculate header size + packet_long_header = is_long_header(packet_type) + if packet_long_header: + header_size = 11 + len(self._peer_cid) + len(self._host_cid) + if (packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL: + token_length = len(self._peer_token) + header_size += size_uint_var(token_length) + token_length + else: + header_size = 3 + len(self._peer_cid) + + # check we have enough space + if packet_start + header_size >= self._buffer_capacity: + raise QuicPacketBuilderStop + + # determine ack epoch + if packet_type == PACKET_TYPE_INITIAL: + epoch = Epoch.INITIAL + elif packet_type == PACKET_TYPE_HANDSHAKE: + epoch = Epoch.HANDSHAKE + else: + epoch = Epoch.ONE_RTT + + self._header_size = header_size + self._packet = QuicSentPacket( + epoch=epoch, + in_flight=False, + is_ack_eliciting=False, + is_crypto_packet=False, + packet_number=self._packet_number, + packet_type=packet_type, + ) + self._packet_crypto = crypto + self._packet_long_header = packet_long_header + self._packet_start = packet_start + self._packet_type = packet_type + self.quic_logger_frames = self._packet.quic_logger_frames + + buf.seek(self._packet_start + self._header_size) + + def _end_packet(self) -> None: + """ + Ends the current packet. 
+ """ + buf = self._buffer + packet_size = buf.tell() - self._packet_start + if packet_size > self._header_size: + # pad initial datagram + if ( + self._is_client + and self._packet_type == PACKET_TYPE_INITIAL + and self._packet.is_crypto_packet + ): + if self.remaining_flight_space: + buf.push_bytes(bytes(self.remaining_flight_space)) + packet_size = buf.tell() - self._packet_start + self._packet.in_flight = True + + # log frame + if self._quic_logger is not None: + self._packet.quic_logger_frames.append( + self._quic_logger.encode_padding_frame() + ) + + # write header + if self._packet_long_header: + length = ( + packet_size + - self._header_size + + PACKET_NUMBER_SEND_SIZE + + self._packet_crypto.aead_tag_size + ) + + buf.seek(self._packet_start) + buf.push_uint8(self._packet_type | (PACKET_NUMBER_SEND_SIZE - 1)) + buf.push_uint32(self._version) + buf.push_uint8(len(self._peer_cid)) + buf.push_bytes(self._peer_cid) + buf.push_uint8(len(self._host_cid)) + buf.push_bytes(self._host_cid) + if (self._packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL: + buf.push_uint_var(len(self._peer_token)) + buf.push_bytes(self._peer_token) + buf.push_uint16(length | 0x4000) + buf.push_uint16(self._packet_number & 0xFFFF) + else: + buf.seek(self._packet_start) + buf.push_uint8( + self._packet_type + | (self._spin_bit << 5) + | (self._packet_crypto.key_phase << 2) + | (PACKET_NUMBER_SEND_SIZE - 1) + ) + buf.push_bytes(self._peer_cid) + buf.push_uint16(self._packet_number & 0xFFFF) + + # check whether we need padding + padding_size = ( + PACKET_NUMBER_MAX_SIZE + - PACKET_NUMBER_SEND_SIZE + + self._header_size + - packet_size + ) + if padding_size > 0: + buf.seek(self._packet_start + packet_size) + buf.push_bytes(bytes(padding_size)) + packet_size += padding_size + self._packet.in_flight = True + + # log frame + if self._quic_logger is not None: + self._packet.quic_logger_frames.append( + self._quic_logger.encode_padding_frame() + ) + + # encrypt in place + plain = 
buf.data_slice(self._packet_start, self._packet_start + packet_size) + buf.seek(self._packet_start) + buf.push_bytes( + self._packet_crypto.encrypt_packet( + plain[0 : self._header_size], + plain[self._header_size : packet_size], + self._packet_number, + ) + ) + self._packet.sent_bytes = buf.tell() - self._packet_start + self._packets.append(self._packet) + if self._packet.in_flight: + self._datagram_flight_bytes += self._packet.sent_bytes + + # short header packets cannot be coallesced, we need a new datagram + if not self._packet_long_header: + self._flush_current_datagram() + + self._packet_number += 1 + else: + # "cancel" the packet + buf.seek(self._packet_start) + + self._packet = None + self.quic_logger_frames = None + + def _flush_current_datagram(self) -> None: + datagram_bytes = self._buffer.tell() + if datagram_bytes: + self._datagrams.append(self._buffer.data) + self._flight_bytes += self._datagram_flight_bytes + self._total_bytes += datagram_bytes + self._datagram_init = True + self._buffer.seek(0) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/rangeset.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/rangeset.py new file mode 100644 index 000000000000..86086c9c74f7 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/rangeset.py @@ -0,0 +1,98 @@ +from collections.abc import Sequence +from typing import Any, Iterable, List, Optional + + +class RangeSet(Sequence): + def __init__(self, ranges: Iterable[range] = []): + self.__ranges: List[range] = [] + for r in ranges: + assert r.step == 1 + self.add(r.start, r.stop) + + def add(self, start: int, stop: Optional[int] = None) -> None: + if stop is None: + stop = start + 1 + assert stop > start + + for i, r in enumerate(self.__ranges): + # the added range is entirely before current item, insert here + if stop < r.start: + self.__ranges.insert(i, range(start, stop)) + return + + # the added range is entirely 
after current item, keep looking + if start > r.stop: + continue + + # the added range touches the current item, merge it + start = min(start, r.start) + stop = max(stop, r.stop) + while i < len(self.__ranges) - 1 and self.__ranges[i + 1].start <= stop: + stop = max(self.__ranges[i + 1].stop, stop) + self.__ranges.pop(i + 1) + self.__ranges[i] = range(start, stop) + return + + # the added range is entirely after all existing items, append it + self.__ranges.append(range(start, stop)) + + def bounds(self) -> range: + return range(self.__ranges[0].start, self.__ranges[-1].stop) + + def shift(self) -> range: + return self.__ranges.pop(0) + + def subtract(self, start: int, stop: int) -> None: + assert stop > start + + i = 0 + while i < len(self.__ranges): + r = self.__ranges[i] + + # the removed range is entirely before current item, stop here + if stop <= r.start: + return + + # the removed range is entirely after current item, keep looking + if start >= r.stop: + i += 1 + continue + + # the removed range completely covers the current item, remove it + if start <= r.start and stop >= r.stop: + self.__ranges.pop(i) + continue + + # the removed range touches the current item + if start > r.start: + self.__ranges[i] = range(r.start, start) + if stop < r.stop: + self.__ranges.insert(i + 1, range(stop, r.stop)) + else: + self.__ranges[i] = range(stop, r.stop) + + i += 1 + + def __bool__(self) -> bool: + raise NotImplementedError + + def __contains__(self, val: Any) -> bool: + for r in self.__ranges: + if val in r: + return True + return False + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RangeSet): + return NotImplemented + + return self.__ranges == other.__ranges + + def __getitem__(self, key: Any) -> range: + return self.__ranges[key] + + def __len__(self) -> int: + return len(self.__ranges) + + def __repr__(self) -> str: + return "RangeSet({})".format(repr(self.__ranges)) diff --git 
a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/recovery.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/recovery.py new file mode 100644 index 000000000000..9abeb1ecf566 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/recovery.py @@ -0,0 +1,497 @@ +import math +from typing import Callable, Dict, Iterable, List, Optional + +from .logger import QuicLoggerTrace +from .packet_builder import QuicDeliveryState, QuicSentPacket +from .rangeset import RangeSet + +# loss detection +K_PACKET_THRESHOLD = 3 +K_INITIAL_RTT = 0.5 # seconds +K_GRANULARITY = 0.001 # seconds +K_TIME_THRESHOLD = 9 / 8 +K_MICRO_SECOND = 0.000001 +K_SECOND = 1.0 + +# congestion control +K_MAX_DATAGRAM_SIZE = 1280 +K_INITIAL_WINDOW = 10 * K_MAX_DATAGRAM_SIZE +K_MINIMUM_WINDOW = 2 * K_MAX_DATAGRAM_SIZE +K_LOSS_REDUCTION_FACTOR = 0.5 + + +class QuicPacketSpace: + def __init__(self) -> None: + self.ack_at: Optional[float] = None + self.ack_queue = RangeSet() + self.discarded = False + self.expected_packet_number = 0 + self.largest_received_packet = -1 + self.largest_received_time: Optional[float] = None + + # sent packets and loss + self.ack_eliciting_in_flight = 0 + self.largest_acked_packet = 0 + self.loss_time: Optional[float] = None + self.sent_packets: Dict[int, QuicSentPacket] = {} + + +class QuicPacketPacer: + def __init__(self) -> None: + self.bucket_max: float = 0.0 + self.bucket_time: float = 0.0 + self.evaluation_time: float = 0.0 + self.packet_time: Optional[float] = None + + def next_send_time(self, now: float) -> float: + if self.packet_time is not None: + self.update_bucket(now=now) + if self.bucket_time <= 0: + return now + self.packet_time + return None + + def update_after_send(self, now: float) -> None: + if self.packet_time is not None: + self.update_bucket(now=now) + if self.bucket_time < self.packet_time: + self.bucket_time = 0.0 + else: + self.bucket_time -= self.packet_time + + def 
update_bucket(self, now: float) -> None: + if now > self.evaluation_time: + self.bucket_time = min( + self.bucket_time + (now - self.evaluation_time), self.bucket_max + ) + self.evaluation_time = now + + def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None: + pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND) + self.packet_time = max( + K_MICRO_SECOND, min(K_MAX_DATAGRAM_SIZE / pacing_rate, K_SECOND) + ) + + self.bucket_max = ( + max( + 2 * K_MAX_DATAGRAM_SIZE, + min(congestion_window // 4, 16 * K_MAX_DATAGRAM_SIZE), + ) + / pacing_rate + ) + if self.bucket_time > self.bucket_max: + self.bucket_time = self.bucket_max + + +class QuicCongestionControl: + """ + New Reno congestion control. + """ + + def __init__(self) -> None: + self.bytes_in_flight = 0 + self.congestion_window = K_INITIAL_WINDOW + self._congestion_recovery_start_time = 0.0 + self._congestion_stash = 0 + self._rtt_monitor = QuicRttMonitor() + self.ssthresh: Optional[int] = None + + def on_packet_acked(self, packet: QuicSentPacket) -> None: + self.bytes_in_flight -= packet.sent_bytes + + # don't increase window in congestion recovery + if packet.sent_time <= self._congestion_recovery_start_time: + return + + if self.ssthresh is None or self.congestion_window < self.ssthresh: + # slow start + self.congestion_window += packet.sent_bytes + else: + # congestion avoidance + self._congestion_stash += packet.sent_bytes + count = self._congestion_stash // self.congestion_window + if count: + self._congestion_stash -= count * self.congestion_window + self.congestion_window += count * K_MAX_DATAGRAM_SIZE + + def on_packet_sent(self, packet: QuicSentPacket) -> None: + self.bytes_in_flight += packet.sent_bytes + + def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None: + for packet in packets: + self.bytes_in_flight -= packet.sent_bytes + + def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None: + lost_largest_time = 0.0 + for packet 
in packets: + self.bytes_in_flight -= packet.sent_bytes + lost_largest_time = packet.sent_time + + # start a new congestion event if packet was sent after the + # start of the previous congestion recovery period. + if lost_largest_time > self._congestion_recovery_start_time: + self._congestion_recovery_start_time = now + self.congestion_window = max( + int(self.congestion_window * K_LOSS_REDUCTION_FACTOR), K_MINIMUM_WINDOW + ) + self.ssthresh = self.congestion_window + + # TODO : collapse congestion window if persistent congestion + + def on_rtt_measurement(self, latest_rtt: float, now: float) -> None: + # check whether we should exit slow start + if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing( + latest_rtt, now + ): + self.ssthresh = self.congestion_window + + +class QuicPacketRecovery: + """ + Packet loss and congestion controller. + """ + + def __init__( + self, + is_client_without_1rtt: bool, + send_probe: Callable[[], None], + quic_logger: Optional[QuicLoggerTrace] = None, + ) -> None: + self.is_client_without_1rtt = is_client_without_1rtt + self.max_ack_delay = 0.025 + self.spaces: List[QuicPacketSpace] = [] + + # callbacks + self._quic_logger = quic_logger + self._send_probe = send_probe + + # loss detection + self._pto_count = 0 + self._rtt_initialized = False + self._rtt_latest = 0.0 + self._rtt_min = math.inf + self._rtt_smoothed = 0.0 + self._rtt_variance = 0.0 + self._time_of_last_sent_ack_eliciting_packet = 0.0 + + # congestion control + self._cc = QuicCongestionControl() + self._pacer = QuicPacketPacer() + + @property + def bytes_in_flight(self) -> int: + return self._cc.bytes_in_flight + + @property + def congestion_window(self) -> int: + return self._cc.congestion_window + + def discard_space(self, space: QuicPacketSpace) -> None: + assert space in self.spaces + + self._cc.on_packets_expired( + filter(lambda x: x.in_flight, space.sent_packets.values()) + ) + space.sent_packets.clear() + + space.ack_at = None + 
space.ack_eliciting_in_flight = 0 + space.loss_time = None + + if self._quic_logger is not None: + self._log_metrics_updated() + + def get_earliest_loss_space(self) -> Optional[QuicPacketSpace]: + loss_space = None + for space in self.spaces: + if space.loss_time is not None and ( + loss_space is None or space.loss_time < loss_space.loss_time + ): + loss_space = space + return loss_space + + def get_loss_detection_time(self) -> float: + # loss timer + loss_space = self.get_earliest_loss_space() + if loss_space is not None: + return loss_space.loss_time + + # packet timer + if ( + self.is_client_without_1rtt + or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0 + ): + if not self._rtt_initialized: + timeout = 2 * K_INITIAL_RTT * (2 ** self._pto_count) + else: + timeout = self.get_probe_timeout() * (2 ** self._pto_count) + return self._time_of_last_sent_ack_eliciting_packet + timeout + + return None + + def get_probe_timeout(self) -> float: + return ( + self._rtt_smoothed + + max(4 * self._rtt_variance, K_GRANULARITY) + + self.max_ack_delay + ) + + def on_ack_received( + self, + space: QuicPacketSpace, + ack_rangeset: RangeSet, + ack_delay: float, + now: float, + ) -> None: + """ + Update metrics as the result of an ACK being received. 
+ """ + is_ack_eliciting = False + largest_acked = ack_rangeset.bounds().stop - 1 + largest_newly_acked = None + largest_sent_time = None + + if largest_acked > space.largest_acked_packet: + space.largest_acked_packet = largest_acked + + for packet_number in sorted(space.sent_packets.keys()): + if packet_number > largest_acked: + break + if packet_number in ack_rangeset: + # remove packet and update counters + packet = space.sent_packets.pop(packet_number) + if packet.is_ack_eliciting: + is_ack_eliciting = True + space.ack_eliciting_in_flight -= 1 + if packet.in_flight: + self._cc.on_packet_acked(packet) + largest_newly_acked = packet_number + largest_sent_time = packet.sent_time + + # trigger callbacks + for handler, args in packet.delivery_handlers: + handler(QuicDeliveryState.ACKED, *args) + + # nothing to do if there are no newly acked packets + if largest_newly_acked is None: + return + + if largest_acked == largest_newly_acked and is_ack_eliciting: + latest_rtt = now - largest_sent_time + log_rtt = True + + # limit ACK delay to max_ack_delay + ack_delay = min(ack_delay, self.max_ack_delay) + + # update RTT estimate, which cannot be < 1 ms + self._rtt_latest = max(latest_rtt, 0.001) + if self._rtt_latest < self._rtt_min: + self._rtt_min = self._rtt_latest + if self._rtt_latest > self._rtt_min + ack_delay: + self._rtt_latest -= ack_delay + + if not self._rtt_initialized: + self._rtt_initialized = True + self._rtt_variance = latest_rtt / 2 + self._rtt_smoothed = latest_rtt + else: + self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs( + self._rtt_min - self._rtt_latest + ) + self._rtt_smoothed = ( + 7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest + ) + + # inform congestion controller + self._cc.on_rtt_measurement(latest_rtt, now=now) + self._pacer.update_rate( + congestion_window=self._cc.congestion_window, + smoothed_rtt=self._rtt_smoothed, + ) + + else: + log_rtt = False + + self._detect_loss(space, now=now) + + if self._quic_logger is not 
None: + self._log_metrics_updated(log_rtt=log_rtt) + + self._pto_count = 0 + + def on_loss_detection_timeout(self, now: float) -> None: + loss_space = self.get_earliest_loss_space() + if loss_space is not None: + self._detect_loss(loss_space, now=now) + else: + self._pto_count += 1 + + # reschedule some data + for space in self.spaces: + self._on_packets_lost( + tuple( + filter( + lambda i: i.is_crypto_packet, space.sent_packets.values() + ) + ), + space=space, + now=now, + ) + + self._send_probe() + + def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None: + space.sent_packets[packet.packet_number] = packet + + if packet.is_ack_eliciting: + space.ack_eliciting_in_flight += 1 + if packet.in_flight: + if packet.is_ack_eliciting: + self._time_of_last_sent_ack_eliciting_packet = packet.sent_time + + # add packet to bytes in flight + self._cc.on_packet_sent(packet) + + if self._quic_logger is not None: + self._log_metrics_updated() + + def _detect_loss(self, space: QuicPacketSpace, now: float) -> None: + """ + Check whether any packets should be declared lost. 
+ """ + loss_delay = K_TIME_THRESHOLD * ( + max(self._rtt_latest, self._rtt_smoothed) + if self._rtt_initialized + else K_INITIAL_RTT + ) + packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD + time_threshold = now - loss_delay + + lost_packets = [] + space.loss_time = None + for packet_number, packet in space.sent_packets.items(): + if packet_number > space.largest_acked_packet: + break + + if packet_number <= packet_threshold or packet.sent_time <= time_threshold: + lost_packets.append(packet) + else: + packet_loss_time = packet.sent_time + loss_delay + if space.loss_time is None or space.loss_time > packet_loss_time: + space.loss_time = packet_loss_time + + self._on_packets_lost(lost_packets, space=space, now=now) + + def _log_metrics_updated(self, log_rtt=False) -> None: + data = { + "bytes_in_flight": self._cc.bytes_in_flight, + "cwnd": self._cc.congestion_window, + } + if self._cc.ssthresh is not None: + data["ssthresh"] = self._cc.ssthresh + + if log_rtt: + data.update( + { + "latest_rtt": self._quic_logger.encode_time(self._rtt_latest), + "min_rtt": self._quic_logger.encode_time(self._rtt_min), + "smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed), + "rtt_variance": self._quic_logger.encode_time(self._rtt_variance), + } + ) + + self._quic_logger.log_event( + category="recovery", event="metrics_updated", data=data + ) + + def _on_packets_lost( + self, packets: Iterable[QuicSentPacket], space: QuicPacketSpace, now: float + ) -> None: + lost_packets_cc = [] + for packet in packets: + del space.sent_packets[packet.packet_number] + + if packet.in_flight: + lost_packets_cc.append(packet) + + if packet.is_ack_eliciting: + space.ack_eliciting_in_flight -= 1 + + if self._quic_logger is not None: + self._quic_logger.log_event( + category="recovery", + event="packet_lost", + data={ + "type": self._quic_logger.packet_type(packet.packet_type), + "packet_number": str(packet.packet_number), + }, + ) + self._log_metrics_updated() + + # trigger 
callbacks + for handler, args in packet.delivery_handlers: + handler(QuicDeliveryState.LOST, *args) + + # inform congestion controller + if lost_packets_cc: + self._cc.on_packets_lost(lost_packets_cc, now=now) + self._pacer.update_rate( + congestion_window=self._cc.congestion_window, + smoothed_rtt=self._rtt_smoothed, + ) + if self._quic_logger is not None: + self._log_metrics_updated() + + +class QuicRttMonitor: + """ + Roundtrip time monitor for HyStart. + """ + + def __init__(self) -> None: + self._increases = 0 + self._last_time = None + self._ready = False + self._size = 5 + + self._filtered_min: Optional[float] = None + + self._sample_idx = 0 + self._sample_max: Optional[float] = None + self._sample_min: Optional[float] = None + self._sample_time = 0.0 + self._samples = [0.0 for i in range(self._size)] + + def add_rtt(self, rtt: float) -> None: + self._samples[self._sample_idx] = rtt + self._sample_idx += 1 + + if self._sample_idx >= self._size: + self._sample_idx = 0 + self._ready = True + + if self._ready: + self._sample_max = self._samples[0] + self._sample_min = self._samples[0] + for sample in self._samples[1:]: + if sample < self._sample_min: + self._sample_min = sample + elif sample > self._sample_max: + self._sample_max = sample + + def is_rtt_increasing(self, rtt: float, now: float) -> bool: + if now > self._sample_time + K_GRANULARITY: + self.add_rtt(rtt) + self._sample_time = now + + if self._ready: + if self._filtered_min is None or self._filtered_min > self._sample_max: + self._filtered_min = self._sample_max + + delta = self._sample_min - self._filtered_min + if delta * 4 >= self._filtered_min: + self._increases += 1 + if self._increases >= self._size: + return True + elif delta > 0: + self._increases = 0 + return False diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/retry.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/retry.py new file mode 100644 index 000000000000..c45c3a5c55c9 
--- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/retry.py @@ -0,0 +1,39 @@ +import ipaddress + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import padding, rsa + +from .connection import NetworkAddress + + +def encode_address(addr: NetworkAddress) -> bytes: + return ipaddress.ip_address(addr[0]).packed + bytes([addr[1] >> 8, addr[1] & 0xFF]) + + +class QuicRetryTokenHandler: + def __init__(self) -> None: + self._key = rsa.generate_private_key( + public_exponent=65537, key_size=1024, backend=default_backend() + ) + + def create_token(self, addr: NetworkAddress, destination_cid: bytes) -> bytes: + retry_message = encode_address(addr) + b"|" + destination_cid + return self._key.public_key().encrypt( + retry_message, + padding.OAEP( + mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None + ), + ) + + def validate_token(self, addr: NetworkAddress, token: bytes) -> bytes: + retry_message = self._key.decrypt( + token, + padding.OAEP( + mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None + ), + ) + encoded_addr, original_connection_id = retry_message.split(b"|", maxsplit=1) + if encoded_addr != encode_address(addr): + raise ValueError("Remote address does not match.") + return original_connection_id diff --git a/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/stream.py b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/stream.py new file mode 100644 index 000000000000..a623686f5d6b --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/src/aioquic/quic/stream.py @@ -0,0 +1,214 @@ +from typing import Optional + +from . 
class QuicStream:
    """
    State for a single QUIC stream: reassembly of received data and
    buffering/retransmission tracking of sent data.
    """

    def __init__(
        self,
        stream_id: Optional[int] = None,
        max_stream_data_local: int = 0,
        max_stream_data_remote: int = 0,
    ) -> None:
        self.is_blocked = False
        self.max_stream_data_local = max_stream_data_local
        self.max_stream_data_local_sent = max_stream_data_local
        self.max_stream_data_remote = max_stream_data_remote
        self.send_buffer_is_empty = True

        self._recv_buffer = bytearray()
        self._recv_buffer_fin: Optional[int] = None
        self._recv_buffer_start = 0  # the offset for the start of the buffer
        self._recv_highest = 0  # the highest offset ever seen
        self._recv_ranges = RangeSet()

        self._send_acked = RangeSet()
        self._send_buffer = bytearray()
        self._send_buffer_fin: Optional[int] = None
        self._send_buffer_start = 0  # the offset for the start of the buffer
        self._send_buffer_stop = 0  # the offset for the stop of the buffer
        self._send_highest = 0
        self._send_pending = RangeSet()
        self._send_pending_eof = False

        self.__stream_id = stream_id

    @property
    def stream_id(self) -> Optional[int]:
        return self.__stream_id

    # reader

    def add_frame(self, frame: QuicStreamFrame) -> Optional[events.StreamDataReceived]:
        """
        Add a frame of received data.

        Returns a StreamDataReceived event when in-order data (or the end
        of the stream) becomes available, otherwise None.
        """
        pos = frame.offset - self._recv_buffer_start
        count = len(frame.data)
        frame_end = frame.offset + count

        # we should receive no more data beyond FIN!
        if self._recv_buffer_fin is not None and frame_end > self._recv_buffer_fin:
            raise Exception("Data received beyond FIN")
        if frame.fin:
            self._recv_buffer_fin = frame_end
        if frame_end > self._recv_highest:
            self._recv_highest = frame_end

        # fast path: new in-order chunk, nothing buffered
        if pos == 0 and count and not self._recv_buffer:
            self._recv_buffer_start += count
            return events.StreamDataReceived(
                data=frame.data, end_stream=frame.fin, stream_id=self.__stream_id
            )

        # discard data duplicating what was already delivered
        if pos < 0:
            frame.data = frame.data[-pos:]
            frame.offset -= pos
            pos = 0

        # marked received range
        if frame_end > frame.offset:
            self._recv_ranges.add(frame.offset, frame_end)

        # add new data, padding the buffer if the frame leaves a gap
        gap = pos - len(self._recv_buffer)
        if gap > 0:
            self._recv_buffer += bytearray(gap)
        self._recv_buffer[pos : pos + count] = frame.data

        # return data from the front of the buffer
        data = self._pull_data()
        end_stream = self._recv_buffer_start == self._recv_buffer_fin
        if data or end_stream:
            return events.StreamDataReceived(
                data=data, end_stream=end_stream, stream_id=self.__stream_id
            )
        else:
            return None

    def _pull_data(self) -> bytes:
        """
        Remove data from the front of the buffer.
        """
        try:
            has_data_to_read = self._recv_ranges[0].start == self._recv_buffer_start
        except IndexError:
            has_data_to_read = False
        if not has_data_to_read:
            return b""

        r = self._recv_ranges.shift()
        pos = r.stop - r.start
        data = bytes(self._recv_buffer[:pos])
        del self._recv_buffer[:pos]
        self._recv_buffer_start = r.stop
        return data

    # writer

    @property
    def next_send_offset(self) -> int:
        """
        The offset for the next frame to send.

        This is used to determine the space needed for the frame's `offset` field.
        """
        try:
            return self._send_pending[0].start
        except IndexError:
            return self._send_buffer_stop

    def get_frame(
        self, max_size: int, max_offset: Optional[int] = None
    ) -> Optional[QuicStreamFrame]:
        """
        Get a frame of data to send, at most `max_size` bytes long and not
        extending beyond `max_offset` (flow control), or None.
        """
        # get the first pending data range
        try:
            r = self._send_pending[0]
        except IndexError:
            if self._send_pending_eof:
                # FIN only
                self._send_pending_eof = False
                return QuicStreamFrame(fin=True, offset=self._send_buffer_fin)

            self.send_buffer_is_empty = True
            return None

        # apply flow control
        start = r.start
        stop = min(r.stop, start + max_size)
        if max_offset is not None and stop > max_offset:
            stop = max_offset
        if stop <= start:
            return None

        # create frame
        frame = QuicStreamFrame(
            data=bytes(
                self._send_buffer[
                    start - self._send_buffer_start : stop - self._send_buffer_start
                ]
            ),
            offset=start,
        )
        self._send_pending.subtract(start, stop)

        # track the highest offset ever sent
        if stop > self._send_highest:
            self._send_highest = stop

        # if the buffer is empty and EOF was written, set the FIN bit
        if self._send_buffer_fin == stop:
            frame.fin = True
            self._send_pending_eof = False

        return frame

    def on_data_delivery(
        self, delivery: QuicDeliveryState, start: int, stop: int
    ) -> None:
        """
        Callback when sent data is ACK'd or lost.
        """
        self.send_buffer_is_empty = False
        if delivery == QuicDeliveryState.ACKED:
            if stop > start:
                self._send_acked.add(start, stop)
                first_range = self._send_acked[0]
                # release buffered bytes once a contiguous prefix is ACK'd
                if first_range.start == self._send_buffer_start:
                    size = first_range.stop - first_range.start
                    self._send_acked.shift()
                    self._send_buffer_start += size
                    del self._send_buffer[:size]
        else:
            # data was lost: schedule it (and possibly the FIN) for resend
            if stop > start:
                self._send_pending.add(start, stop)
            if stop == self._send_buffer_fin:
                # fix: this previously assigned `self.send_buffer_empty`,
                # a dead attribute - __init__ and every other site use
                # `send_buffer_is_empty`
                self.send_buffer_is_empty = False
                self._send_pending_eof = True

    def write(self, data: bytes, end_stream: bool = False) -> None:
        """
        Write some data bytes to the QUIC stream.
        """
        assert self._send_buffer_fin is None, "cannot call write() after FIN"
        size = len(data)

        if size:
            self.send_buffer_is_empty = False
            self._send_pending.add(
                self._send_buffer_stop, self._send_buffer_stop + size
            )
            self._send_buffer += data
            self._send_buffer_stop += size
        if end_stream:
            self.send_buffer_is_empty = False
            self._send_buffer_fin = self._send_buffer_stop
            self._send_pending_eof = True
cryptography.hazmat.primitives import hashes, hmac, serialization +from cryptography.hazmat.primitives.asymmetric import ( + dsa, + ec, + padding, + rsa, + x448, + x25519, +) +from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand +from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat + +from .buffer import Buffer + +binding = Binding() +binding.init_static_locks() +ffi = binding.ffi +lib = binding.lib + +TLS_VERSION_1_2 = 0x0303 +TLS_VERSION_1_3 = 0x0304 +TLS_VERSION_1_3_DRAFT_28 = 0x7F1C +TLS_VERSION_1_3_DRAFT_27 = 0x7F1B +TLS_VERSION_1_3_DRAFT_26 = 0x7F1A + +T = TypeVar("T") + +# facilitate mocking for the test suite +utcnow = datetime.datetime.utcnow + + +class AlertDescription(IntEnum): + close_notify = 0 + unexpected_message = 10 + bad_record_mac = 20 + record_overflow = 22 + handshake_failure = 40 + bad_certificate = 42 + unsupported_certificate = 43 + certificate_revoked = 44 + certificate_expired = 45 + certificate_unknown = 46 + illegal_parameter = 47 + unknown_ca = 48 + access_denied = 49 + decode_error = 50 + decrypt_error = 51 + protocol_version = 70 + insufficient_security = 71 + internal_error = 80 + inappropriate_fallback = 86 + user_canceled = 90 + missing_extension = 109 + unsupported_extension = 110 + unrecognized_name = 112 + bad_certificate_status_response = 113 + unknown_psk_identity = 115 + certificate_required = 116 + no_application_protocol = 120 + + +class Alert(Exception): + description: AlertDescription + + +class AlertBadCertificate(Alert): + description = AlertDescription.bad_certificate + + +class AlertCertificateExpired(Alert): + description = AlertDescription.certificate_expired + + +class AlertDecryptError(Alert): + description = AlertDescription.decrypt_error + + +class AlertHandshakeFailure(Alert): + description = AlertDescription.handshake_failure + + +class AlertIllegalParameter(Alert): + description = AlertDescription.illegal_parameter + + +class AlertInternalError(Alert): + description = 
AlertDescription.internal_error + + +class AlertProtocolVersion(Alert): + description = AlertDescription.protocol_version + + +class AlertUnexpectedMessage(Alert): + description = AlertDescription.unexpected_message + + +class Direction(Enum): + DECRYPT = 0 + ENCRYPT = 1 + + +class Epoch(Enum): + INITIAL = 0 + ZERO_RTT = 1 + HANDSHAKE = 2 + ONE_RTT = 3 + + +class State(Enum): + CLIENT_HANDSHAKE_START = 0 + CLIENT_EXPECT_SERVER_HELLO = 1 + CLIENT_EXPECT_ENCRYPTED_EXTENSIONS = 2 + CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE = 3 + CLIENT_EXPECT_CERTIFICATE_CERTIFICATE = 4 + CLIENT_EXPECT_CERTIFICATE_VERIFY = 5 + CLIENT_EXPECT_FINISHED = 6 + CLIENT_POST_HANDSHAKE = 7 + + SERVER_EXPECT_CLIENT_HELLO = 8 + SERVER_EXPECT_FINISHED = 9 + SERVER_POST_HANDSHAKE = 10 + + +def hkdf_label(label: bytes, hash_value: bytes, length: int) -> bytes: + full_label = b"tls13 " + label + return ( + struct.pack("!HB", length, len(full_label)) + + full_label + + struct.pack("!B", len(hash_value)) + + hash_value + ) + + +def hkdf_expand_label( + algorithm: hashes.HashAlgorithm, + secret: bytes, + label: bytes, + hash_value: bytes, + length: int, +) -> bytes: + return HKDFExpand( + algorithm=algorithm, + length=length, + info=hkdf_label(label, hash_value, length), + backend=default_backend(), + ).derive(secret) + + +def hkdf_extract( + algorithm: hashes.HashAlgorithm, salt: bytes, key_material: bytes +) -> bytes: + h = hmac.HMAC(salt, algorithm, backend=default_backend()) + h.update(key_material) + return h.finalize() + + +def load_pem_private_key( + data: bytes, password: Optional[bytes] +) -> Union[dsa.DSAPrivateKey, ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey]: + """ + Load a PEM-encoded private key. + """ + return serialization.load_pem_private_key( + data, password=password, backend=default_backend() + ) + + +def load_pem_x509_certificates(data: bytes) -> List[x509.Certificate]: + """ + Load a chain of PEM-encoded X509 certificates. 
+ """ + boundary = b"-----END CERTIFICATE-----\n" + certificates = [] + for chunk in data.split(boundary): + if chunk: + certificates.append( + x509.load_pem_x509_certificate( + chunk + boundary, backend=default_backend() + ) + ) + return certificates + + +def openssl_assert(ok: bool, func: str) -> None: + if not ok: + lib.ERR_clear_error() + raise AlertInternalError("OpenSSL call to %s failed" % func) + + +def openssl_decode_string(charp) -> str: + return ffi.string(charp).decode("utf-8") if charp else "" + + +def openssl_encode_path(s: Optional[str]) -> Any: + if s is not None: + return os.fsencode(s) + return ffi.NULL + + +def cert_x509_ptr(certificate: x509.Certificate) -> Any: + """ + Accessor for private attribute. + """ + return getattr(certificate, "_x509") + + +def verify_certificate( + certificate: x509.Certificate, + chain: List[x509.Certificate] = [], + server_name: Optional[str] = None, + cadata: Optional[bytes] = None, + cafile: Optional[str] = None, + capath: Optional[str] = None, +) -> None: + # verify dates + now = utcnow() + if now < certificate.not_valid_before: + raise AlertCertificateExpired("Certificate is not valid yet") + if now > certificate.not_valid_after: + raise AlertCertificateExpired("Certificate is no longer valid") + + # verify subject + if server_name is not None: + subject = [] + subjectAltName: List[Tuple[str, str]] = [] + for attr in certificate.subject: + if attr.oid == x509.NameOID.COMMON_NAME: + subject.append((("commonName", attr.value),)) + for ext in certificate.extensions: + if isinstance(ext.value, x509.SubjectAlternativeName): + for name in ext.value: + if isinstance(name, x509.DNSName): + subjectAltName.append(("DNS", name.value)) + + try: + ssl.match_hostname( + {"subject": tuple(subject), "subjectAltName": tuple(subjectAltName)}, + server_name, + ) + except ssl.CertificateError as exc: + raise AlertBadCertificate("\n".join(exc.args)) from exc + + # verify certificate chain + store = lib.X509_STORE_new() + 
openssl_assert(store != ffi.NULL, "X509_store_new") + store = ffi.gc(store, lib.X509_STORE_free) + + # load default CAs + openssl_assert( + lib.X509_STORE_set_default_paths(store), "X509_STORE_set_default_paths" + ) + openssl_assert( + lib.X509_STORE_load_locations( + store, openssl_encode_path(certifi.where()), openssl_encode_path(None), + ), + "X509_STORE_load_locations", + ) + + # load extra CAs + if cadata is not None: + for cert in load_pem_x509_certificates(cadata): + openssl_assert( + lib.X509_STORE_add_cert(store, cert_x509_ptr(cert)), + "X509_STORE_add_cert", + ) + + if cafile is not None or capath is not None: + openssl_assert( + lib.X509_STORE_load_locations( + store, openssl_encode_path(cafile), openssl_encode_path(capath) + ), + "X509_STORE_load_locations", + ) + + chain_stack = lib.sk_X509_new_null() + openssl_assert(chain_stack != ffi.NULL, "sk_X509_new_null") + chain_stack = ffi.gc(chain_stack, lib.sk_X509_free) + for cert in chain: + openssl_assert( + lib.sk_X509_push(chain_stack, cert_x509_ptr(cert)), "sk_X509_push" + ) + + store_ctx = lib.X509_STORE_CTX_new() + openssl_assert(store_ctx != ffi.NULL, "X509_STORE_CTX_new") + store_ctx = ffi.gc(store_ctx, lib.X509_STORE_CTX_free) + openssl_assert( + lib.X509_STORE_CTX_init( + store_ctx, store, cert_x509_ptr(certificate), chain_stack + ), + "X509_STORE_CTX_init", + ) + + res = lib.X509_verify_cert(store_ctx) + if not res: + err = lib.X509_STORE_CTX_get_error(store_ctx) + err_str = openssl_decode_string(lib.X509_verify_cert_error_string(err)) + raise AlertBadCertificate(err_str) + + +class CipherSuite(IntEnum): + AES_128_GCM_SHA256 = 0x1301 + AES_256_GCM_SHA384 = 0x1302 + CHACHA20_POLY1305_SHA256 = 0x1303 + EMPTY_RENEGOTIATION_INFO_SCSV = 0x00FF + + +class CompressionMethod(IntEnum): + NULL = 0 + + +class ExtensionType(IntEnum): + SERVER_NAME = 0 + STATUS_REQUEST = 5 + SUPPORTED_GROUPS = 10 + SIGNATURE_ALGORITHMS = 13 + ALPN = 16 + COMPRESS_CERTIFICATE = 27 + PRE_SHARED_KEY = 41 + EARLY_DATA = 42 + 
SUPPORTED_VERSIONS = 43 + COOKIE = 44 + PSK_KEY_EXCHANGE_MODES = 45 + KEY_SHARE = 51 + QUIC_TRANSPORT_PARAMETERS = 65445 + ENCRYPTED_SERVER_NAME = 65486 + + +class Group(IntEnum): + SECP256R1 = 0x0017 + SECP384R1 = 0x0018 + SECP521R1 = 0x0019 + X25519 = 0x001D + X448 = 0x001E + GREASE = 0xAAAA + + +class HandshakeType(IntEnum): + CLIENT_HELLO = 1 + SERVER_HELLO = 2 + NEW_SESSION_TICKET = 4 + END_OF_EARLY_DATA = 5 + ENCRYPTED_EXTENSIONS = 8 + CERTIFICATE = 11 + CERTIFICATE_REQUEST = 13 + CERTIFICATE_VERIFY = 15 + FINISHED = 20 + KEY_UPDATE = 24 + COMPRESSED_CERTIFICATE = 25 + MESSAGE_HASH = 254 + + +class PskKeyExchangeMode(IntEnum): + PSK_KE = 0 + PSK_DHE_KE = 1 + + +class SignatureAlgorithm(IntEnum): + ECDSA_SECP256R1_SHA256 = 0x0403 + ECDSA_SECP384R1_SHA384 = 0x0503 + ECDSA_SECP521R1_SHA512 = 0x0603 + ED25519 = 0x0807 + ED448 = 0x0808 + RSA_PKCS1_SHA256 = 0x0401 + RSA_PKCS1_SHA384 = 0x0501 + RSA_PKCS1_SHA512 = 0x0601 + RSA_PSS_PSS_SHA256 = 0x0809 + RSA_PSS_PSS_SHA384 = 0x080A + RSA_PSS_PSS_SHA512 = 0x080B + RSA_PSS_RSAE_SHA256 = 0x0804 + RSA_PSS_RSAE_SHA384 = 0x0805 + RSA_PSS_RSAE_SHA512 = 0x0806 + + # legacy + RSA_PKCS1_SHA1 = 0x0201 + SHA1_DSA = 0x0202 + ECDSA_SHA1 = 0x0203 + + +# BLOCKS + + +@contextmanager +def pull_block(buf: Buffer, capacity: int) -> Generator: + length = 0 + for b in buf.pull_bytes(capacity): + length = (length << 8) | b + end = buf.tell() + length + yield length + assert buf.tell() == end + + +@contextmanager +def push_block(buf: Buffer, capacity: int) -> Generator: + """ + Context manager to push a variable-length block, with `capacity` bytes + to write the length. + """ + start = buf.tell() + capacity + buf.seek(start) + yield + end = buf.tell() + length = end - start + while capacity: + buf.seek(start - capacity) + buf.push_uint8((length >> (8 * (capacity - 1))) & 0xFF) + capacity -= 1 + buf.seek(end) + + +# LISTS + + +def pull_list(buf: Buffer, capacity: int, func: Callable[[], T]) -> List[T]: + """ + Pull a list of items. 
+ """ + items = [] + with pull_block(buf, capacity) as length: + end = buf.tell() + length + while buf.tell() < end: + items.append(func()) + return items + + +def push_list( + buf: Buffer, capacity: int, func: Callable[[T], None], values: Sequence[T] +) -> None: + """ + Push a list of items. + """ + with push_block(buf, capacity): + for value in values: + func(value) + + +def pull_opaque(buf: Buffer, capacity: int) -> bytes: + """ + Pull an opaque value prefixed by a length. + """ + with pull_block(buf, capacity) as length: + return buf.pull_bytes(length) + + +def push_opaque(buf: Buffer, capacity: int, value: bytes) -> None: + """ + Push an opaque value prefix by a length. + """ + with push_block(buf, capacity): + buf.push_bytes(value) + + +@contextmanager +def push_extension(buf: Buffer, extension_type: int) -> Generator: + buf.push_uint16(extension_type) + with push_block(buf, 2): + yield + + +# KeyShareEntry + + +KeyShareEntry = Tuple[int, bytes] + + +def pull_key_share(buf: Buffer) -> KeyShareEntry: + group = buf.pull_uint16() + data = pull_opaque(buf, 2) + return (group, data) + + +def push_key_share(buf: Buffer, value: KeyShareEntry) -> None: + buf.push_uint16(value[0]) + push_opaque(buf, 2, value[1]) + + +# ALPN + + +def pull_alpn_protocol(buf: Buffer) -> str: + return pull_opaque(buf, 1).decode("ascii") + + +def push_alpn_protocol(buf: Buffer, protocol: str) -> None: + push_opaque(buf, 1, protocol.encode("ascii")) + + +# PRE SHARED KEY + +PskIdentity = Tuple[bytes, int] + + +def pull_psk_identity(buf: Buffer) -> PskIdentity: + identity = pull_opaque(buf, 2) + obfuscated_ticket_age = buf.pull_uint32() + return (identity, obfuscated_ticket_age) + + +def push_psk_identity(buf: Buffer, entry: PskIdentity) -> None: + push_opaque(buf, 2, entry[0]) + buf.push_uint32(entry[1]) + + +def pull_psk_binder(buf: Buffer) -> bytes: + return pull_opaque(buf, 1) + + +def push_psk_binder(buf: Buffer, binder: bytes) -> None: + push_opaque(buf, 1, binder) + + +# MESSAGES + 
def pull_client_hello(buf: Buffer) -> ClientHello:
    """
    Parse a ClientHello handshake message from `buf`.

    Raises AssertionError on malformed input (wrong message type, wrong
    legacy version, or a pre_shared_key extension that is not last).
    """
    assert buf.pull_uint8() == HandshakeType.CLIENT_HELLO
    # the 3-byte block is the handshake message length; pull_block checks
    # on exit that exactly that many bytes were consumed
    with pull_block(buf, 3):
        # legacy_version is fixed at TLS 1.2 in TLS 1.3 ClientHellos
        assert buf.pull_uint16() == TLS_VERSION_1_2
        client_random = buf.pull_bytes(32)

        hello = ClientHello(
            random=client_random,
            session_id=pull_opaque(buf, 1),
            cipher_suites=pull_list(buf, 2, buf.pull_uint16),
            compression_methods=pull_list(buf, 1, buf.pull_uint8),
        )

        # extensions
        after_psk = False

        def pull_extension() -> None:
            # pre_shared_key MUST be last
            nonlocal after_psk
            assert not after_psk

            extension_type = buf.pull_uint16()
            extension_length = buf.pull_uint16()
            if extension_type == ExtensionType.KEY_SHARE:
                hello.key_share = pull_list(buf, 2, partial(pull_key_share, buf))
            elif extension_type == ExtensionType.SUPPORTED_VERSIONS:
                hello.supported_versions = pull_list(buf, 1, buf.pull_uint16)
            elif extension_type == ExtensionType.SIGNATURE_ALGORITHMS:
                hello.signature_algorithms = pull_list(buf, 2, buf.pull_uint16)
            elif extension_type == ExtensionType.SUPPORTED_GROUPS:
                hello.supported_groups = pull_list(buf, 2, buf.pull_uint16)
            elif extension_type == ExtensionType.PSK_KEY_EXCHANGE_MODES:
                hello.psk_key_exchange_modes = pull_list(buf, 1, buf.pull_uint8)
            elif extension_type == ExtensionType.SERVER_NAME:
                with pull_block(buf, 2):
                    # only name type 0 (host_name) is handled here
                    assert buf.pull_uint8() == 0
                    hello.server_name = pull_opaque(buf, 2).decode("ascii")
            elif extension_type == ExtensionType.ALPN:
                hello.alpn_protocols = pull_list(
                    buf, 2, partial(pull_alpn_protocol, buf)
                )
            elif extension_type == ExtensionType.EARLY_DATA:
                hello.early_data = True
            elif extension_type == ExtensionType.PRE_SHARED_KEY:
                hello.pre_shared_key = OfferedPsks(
                    identities=pull_list(buf, 2, partial(pull_psk_identity, buf)),
                    binders=pull_list(buf, 2, partial(pull_psk_binder, buf)),
                )
                after_psk = True
            else:
                # unknown extensions are preserved verbatim
                hello.other_extensions.append(
                    (extension_type, buf.pull_bytes(extension_length))
                )

        pull_list(buf, 2, pull_extension)

    return hello
buf.push_uint8(0) + push_opaque(buf, 2, hello.server_name.encode("ascii")) + + if hello.alpn_protocols is not None: + with push_extension(buf, ExtensionType.ALPN): + push_list( + buf, 2, partial(push_alpn_protocol, buf), hello.alpn_protocols + ) + + for extension_type, extension_value in hello.other_extensions: + with push_extension(buf, extension_type): + buf.push_bytes(extension_value) + + if hello.early_data: + with push_extension(buf, ExtensionType.EARLY_DATA): + pass + + # pre_shared_key MUST be last + if hello.pre_shared_key is not None: + with push_extension(buf, ExtensionType.PRE_SHARED_KEY): + push_list( + buf, + 2, + partial(push_psk_identity, buf), + hello.pre_shared_key.identities, + ) + push_list( + buf, + 2, + partial(push_psk_binder, buf), + hello.pre_shared_key.binders, + ) + + +@dataclass +class ServerHello: + random: bytes + session_id: bytes + cipher_suite: int + compression_method: int + + # extensions + key_share: Optional[KeyShareEntry] = None + pre_shared_key: Optional[int] = None + supported_version: Optional[int] = None + other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) + + +def pull_server_hello(buf: Buffer) -> ServerHello: + assert buf.pull_uint8() == HandshakeType.SERVER_HELLO + with pull_block(buf, 3): + assert buf.pull_uint16() == TLS_VERSION_1_2 + server_random = buf.pull_bytes(32) + + hello = ServerHello( + random=server_random, + session_id=pull_opaque(buf, 1), + cipher_suite=buf.pull_uint16(), + compression_method=buf.pull_uint8(), + ) + + # extensions + def pull_extension() -> None: + extension_type = buf.pull_uint16() + extension_length = buf.pull_uint16() + if extension_type == ExtensionType.SUPPORTED_VERSIONS: + hello.supported_version = buf.pull_uint16() + elif extension_type == ExtensionType.KEY_SHARE: + hello.key_share = pull_key_share(buf) + elif extension_type == ExtensionType.PRE_SHARED_KEY: + hello.pre_shared_key = buf.pull_uint16() + else: + hello.other_extensions.append( + (extension_type, 
buf.pull_bytes(extension_length)) + ) + + pull_list(buf, 2, pull_extension) + + return hello + + +def push_server_hello(buf: Buffer, hello: ServerHello) -> None: + buf.push_uint8(HandshakeType.SERVER_HELLO) + with push_block(buf, 3): + buf.push_uint16(TLS_VERSION_1_2) + buf.push_bytes(hello.random) + + push_opaque(buf, 1, hello.session_id) + buf.push_uint16(hello.cipher_suite) + buf.push_uint8(hello.compression_method) + + # extensions + with push_block(buf, 2): + if hello.supported_version is not None: + with push_extension(buf, ExtensionType.SUPPORTED_VERSIONS): + buf.push_uint16(hello.supported_version) + + if hello.key_share is not None: + with push_extension(buf, ExtensionType.KEY_SHARE): + push_key_share(buf, hello.key_share) + + if hello.pre_shared_key is not None: + with push_extension(buf, ExtensionType.PRE_SHARED_KEY): + buf.push_uint16(hello.pre_shared_key) + + for extension_type, extension_value in hello.other_extensions: + with push_extension(buf, extension_type): + buf.push_bytes(extension_value) + + +@dataclass +class NewSessionTicket: + ticket_lifetime: int = 0 + ticket_age_add: int = 0 + ticket_nonce: bytes = b"" + ticket: bytes = b"" + + # extensions + max_early_data_size: Optional[int] = None + other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) + + +def pull_new_session_ticket(buf: Buffer) -> NewSessionTicket: + new_session_ticket = NewSessionTicket() + + assert buf.pull_uint8() == HandshakeType.NEW_SESSION_TICKET + with pull_block(buf, 3): + new_session_ticket.ticket_lifetime = buf.pull_uint32() + new_session_ticket.ticket_age_add = buf.pull_uint32() + new_session_ticket.ticket_nonce = pull_opaque(buf, 1) + new_session_ticket.ticket = pull_opaque(buf, 2) + + def pull_extension() -> None: + extension_type = buf.pull_uint16() + extension_length = buf.pull_uint16() + if extension_type == ExtensionType.EARLY_DATA: + new_session_ticket.max_early_data_size = buf.pull_uint32() + else: + new_session_ticket.other_extensions.append( + 
(extension_type, buf.pull_bytes(extension_length)) + ) + + pull_list(buf, 2, pull_extension) + + return new_session_ticket + + +def push_new_session_ticket(buf: Buffer, new_session_ticket: NewSessionTicket) -> None: + buf.push_uint8(HandshakeType.NEW_SESSION_TICKET) + with push_block(buf, 3): + buf.push_uint32(new_session_ticket.ticket_lifetime) + buf.push_uint32(new_session_ticket.ticket_age_add) + push_opaque(buf, 1, new_session_ticket.ticket_nonce) + push_opaque(buf, 2, new_session_ticket.ticket) + + with push_block(buf, 2): + if new_session_ticket.max_early_data_size is not None: + with push_extension(buf, ExtensionType.EARLY_DATA): + buf.push_uint32(new_session_ticket.max_early_data_size) + + for extension_type, extension_value in new_session_ticket.other_extensions: + with push_extension(buf, extension_type): + buf.push_bytes(extension_value) + + +@dataclass +class EncryptedExtensions: + alpn_protocol: Optional[str] = None + early_data: bool = False + + other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) + + +def pull_encrypted_extensions(buf: Buffer) -> EncryptedExtensions: + extensions = EncryptedExtensions() + + assert buf.pull_uint8() == HandshakeType.ENCRYPTED_EXTENSIONS + with pull_block(buf, 3): + + def pull_extension() -> None: + extension_type = buf.pull_uint16() + extension_length = buf.pull_uint16() + if extension_type == ExtensionType.ALPN: + extensions.alpn_protocol = pull_list( + buf, 2, partial(pull_alpn_protocol, buf) + )[0] + elif extension_type == ExtensionType.EARLY_DATA: + extensions.early_data = True + else: + extensions.other_extensions.append( + (extension_type, buf.pull_bytes(extension_length)) + ) + + pull_list(buf, 2, pull_extension) + + return extensions + + +def push_encrypted_extensions(buf: Buffer, extensions: EncryptedExtensions) -> None: + buf.push_uint8(HandshakeType.ENCRYPTED_EXTENSIONS) + with push_block(buf, 3): + with push_block(buf, 2): + if extensions.alpn_protocol is not None: + with 
push_extension(buf, ExtensionType.ALPN): + push_list( + buf, + 2, + partial(push_alpn_protocol, buf), + [extensions.alpn_protocol], + ) + + if extensions.early_data: + with push_extension(buf, ExtensionType.EARLY_DATA): + pass + + for extension_type, extension_value in extensions.other_extensions: + with push_extension(buf, extension_type): + buf.push_bytes(extension_value) + + +CertificateEntry = Tuple[bytes, bytes] + + +@dataclass +class Certificate: + request_context: bytes = b"" + certificates: List[CertificateEntry] = field(default_factory=list) + + +def pull_certificate(buf: Buffer) -> Certificate: + certificate = Certificate() + + assert buf.pull_uint8() == HandshakeType.CERTIFICATE + with pull_block(buf, 3): + certificate.request_context = pull_opaque(buf, 1) + + def pull_certificate_entry(buf: Buffer) -> CertificateEntry: + data = pull_opaque(buf, 3) + extensions = pull_opaque(buf, 2) + return (data, extensions) + + certificate.certificates = pull_list( + buf, 3, partial(pull_certificate_entry, buf) + ) + + return certificate + + +def push_certificate(buf: Buffer, certificate: Certificate) -> None: + buf.push_uint8(HandshakeType.CERTIFICATE) + with push_block(buf, 3): + push_opaque(buf, 1, certificate.request_context) + + def push_certificate_entry(buf: Buffer, entry: CertificateEntry) -> None: + push_opaque(buf, 3, entry[0]) + push_opaque(buf, 2, entry[1]) + + push_list( + buf, 3, partial(push_certificate_entry, buf), certificate.certificates + ) + + +@dataclass +class CertificateVerify: + algorithm: int + signature: bytes + + +def pull_certificate_verify(buf: Buffer) -> CertificateVerify: + assert buf.pull_uint8() == HandshakeType.CERTIFICATE_VERIFY + with pull_block(buf, 3): + algorithm = buf.pull_uint16() + signature = pull_opaque(buf, 2) + + return CertificateVerify(algorithm=algorithm, signature=signature) + + +def push_certificate_verify(buf: Buffer, verify: CertificateVerify) -> None: + buf.push_uint8(HandshakeType.CERTIFICATE_VERIFY) + with 
push_block(buf, 3): + buf.push_uint16(verify.algorithm) + push_opaque(buf, 2, verify.signature) + + +@dataclass +class Finished: + verify_data: bytes = b"" + + +def pull_finished(buf: Buffer) -> Finished: + finished = Finished() + + assert buf.pull_uint8() == HandshakeType.FINISHED + finished.verify_data = pull_opaque(buf, 3) + + return finished + + +def push_finished(buf: Buffer, finished: Finished) -> None: + buf.push_uint8(HandshakeType.FINISHED) + push_opaque(buf, 3, finished.verify_data) + + +# CONTEXT + + +class KeySchedule: + def __init__(self, cipher_suite: CipherSuite): + self.algorithm = cipher_suite_hash(cipher_suite) + self.cipher_suite = cipher_suite + self.generation = 0 + self.hash = hashes.Hash(self.algorithm, default_backend()) + self.hash_empty_value = self.hash.copy().finalize() + self.secret = bytes(self.algorithm.digest_size) + + def certificate_verify_data(self, context_string: bytes) -> bytes: + return b" " * 64 + context_string + b"\x00" + self.hash.copy().finalize() + + def finished_verify_data(self, secret: bytes) -> bytes: + hmac_key = hkdf_expand_label( + algorithm=self.algorithm, + secret=secret, + label=b"finished", + hash_value=b"", + length=self.algorithm.digest_size, + ) + + h = hmac.HMAC(hmac_key, algorithm=self.algorithm, backend=default_backend()) + h.update(self.hash.copy().finalize()) + return h.finalize() + + def derive_secret(self, label: bytes) -> bytes: + return hkdf_expand_label( + algorithm=self.algorithm, + secret=self.secret, + label=label, + hash_value=self.hash.copy().finalize(), + length=self.algorithm.digest_size, + ) + + def extract(self, key_material: Optional[bytes] = None) -> None: + if key_material is None: + key_material = bytes(self.algorithm.digest_size) + + if self.generation: + self.secret = hkdf_expand_label( + algorithm=self.algorithm, + secret=self.secret, + label=b"derived", + hash_value=self.hash_empty_value, + length=self.algorithm.digest_size, + ) + + self.generation += 1 + self.secret = 
hkdf_extract( + algorithm=self.algorithm, salt=self.secret, key_material=key_material + ) + + def update_hash(self, data: bytes) -> None: + self.hash.update(data) + + +class KeyScheduleProxy: + def __init__(self, cipher_suites: List[CipherSuite]): + self.__schedules = dict(map(lambda c: (c, KeySchedule(c)), cipher_suites)) + + def extract(self, key_material: Optional[bytes] = None) -> None: + for k in self.__schedules.values(): + k.extract(key_material) + + def select(self, cipher_suite: CipherSuite) -> KeySchedule: + return self.__schedules[cipher_suite] + + def update_hash(self, data: bytes) -> None: + for k in self.__schedules.values(): + k.update_hash(data) + + +CIPHER_SUITES = { + CipherSuite.AES_128_GCM_SHA256: hashes.SHA256, + CipherSuite.AES_256_GCM_SHA384: hashes.SHA384, + CipherSuite.CHACHA20_POLY1305_SHA256: hashes.SHA256, +} + +SIGNATURE_ALGORITHMS: Dict = { + SignatureAlgorithm.ECDSA_SECP256R1_SHA256: (None, hashes.SHA256), + SignatureAlgorithm.ECDSA_SECP384R1_SHA384: (None, hashes.SHA384), + SignatureAlgorithm.ECDSA_SECP521R1_SHA512: (None, hashes.SHA512), + SignatureAlgorithm.RSA_PKCS1_SHA1: (padding.PKCS1v15, hashes.SHA1), + SignatureAlgorithm.RSA_PKCS1_SHA256: (padding.PKCS1v15, hashes.SHA256), + SignatureAlgorithm.RSA_PKCS1_SHA384: (padding.PKCS1v15, hashes.SHA384), + SignatureAlgorithm.RSA_PKCS1_SHA512: (padding.PKCS1v15, hashes.SHA512), + SignatureAlgorithm.RSA_PSS_RSAE_SHA256: (padding.PSS, hashes.SHA256), + SignatureAlgorithm.RSA_PSS_RSAE_SHA384: (padding.PSS, hashes.SHA384), + SignatureAlgorithm.RSA_PSS_RSAE_SHA512: (padding.PSS, hashes.SHA512), +} + +GROUP_TO_CURVE: Dict = { + Group.SECP256R1: ec.SECP256R1, + Group.SECP384R1: ec.SECP384R1, + Group.SECP521R1: ec.SECP521R1, +} +CURVE_TO_GROUP = dict((v, k) for k, v in GROUP_TO_CURVE.items()) + + +def cipher_suite_hash(cipher_suite: CipherSuite) -> hashes.HashAlgorithm: + return CIPHER_SUITES[cipher_suite]() + + +def decode_public_key( + key_share: KeyShareEntry, +) -> 
Union[ec.EllipticCurvePublicKey, x25519.X25519PublicKey, x448.X448PublicKey, None]: + if key_share[0] == Group.X25519: + return x25519.X25519PublicKey.from_public_bytes(key_share[1]) + elif key_share[0] == Group.X448: + return x448.X448PublicKey.from_public_bytes(key_share[1]) + elif key_share[0] in GROUP_TO_CURVE: + return ec.EllipticCurvePublicKey.from_encoded_point( + GROUP_TO_CURVE[key_share[0]](), key_share[1] + ) + else: + return None + + +def encode_public_key( + public_key: Union[ + ec.EllipticCurvePublicKey, x25519.X25519PublicKey, x448.X448PublicKey + ] +) -> KeyShareEntry: + if isinstance(public_key, x25519.X25519PublicKey): + return (Group.X25519, public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)) + elif isinstance(public_key, x448.X448PublicKey): + return (Group.X448, public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)) + return ( + CURVE_TO_GROUP[public_key.curve.__class__], + public_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint), + ) + + +def negotiate( + supported: List[T], offered: Optional[List[Any]], exc: Optional[Alert] = None +) -> T: + if offered is not None: + for c in supported: + if c in offered: + return c + + if exc is not None: + raise exc + return None + + +def signature_algorithm_params( + signature_algorithm: int, +) -> Union[Tuple[ec.ECDSA], Tuple[padding.AsymmetricPadding, hashes.HashAlgorithm]]: + padding_cls, algorithm_cls = SIGNATURE_ALGORITHMS[signature_algorithm] + algorithm = algorithm_cls() + if padding_cls is None: + return (ec.ECDSA(algorithm),) + elif padding_cls == padding.PSS: + padding_obj = padding_cls( + mgf=padding.MGF1(algorithm), salt_length=algorithm.digest_size + ) + else: + padding_obj = padding_cls() + return padding_obj, algorithm + + +@contextmanager +def push_message( + key_schedule: Union[KeySchedule, KeyScheduleProxy], buf: Buffer +) -> Generator: + hash_start = buf.tell() + yield + key_schedule.update_hash(buf.data_slice(hash_start, buf.tell())) + + +# callback types + + 
+@dataclass +class SessionTicket: + """ + A TLS session ticket for session resumption. + """ + + age_add: int + cipher_suite: CipherSuite + not_valid_after: datetime.datetime + not_valid_before: datetime.datetime + resumption_secret: bytes + server_name: str + ticket: bytes + + max_early_data_size: Optional[int] = None + other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) + + @property + def is_valid(self) -> bool: + now = utcnow() + return now >= self.not_valid_before and now <= self.not_valid_after + + @property + def obfuscated_age(self) -> int: + age = int((utcnow() - self.not_valid_before).total_seconds()) + return (age + self.age_add) % (1 << 32) + + +AlpnHandler = Callable[[str], None] +SessionTicketFetcher = Callable[[bytes], Optional[SessionTicket]] +SessionTicketHandler = Callable[[SessionTicket], None] + + +class Context: + def __init__( + self, + is_client: bool, + alpn_protocols: Optional[List[str]] = None, + cadata: Optional[bytes] = None, + cafile: Optional[str] = None, + capath: Optional[str] = None, + logger: Optional[Union[logging.Logger, logging.LoggerAdapter]] = None, + max_early_data: Optional[int] = None, + server_name: Optional[str] = None, + verify_mode: Optional[int] = None, + ): + # configuration + self._alpn_protocols = alpn_protocols + self._cadata = cadata + self._cafile = cafile + self._capath = capath + self.certificate: Optional[x509.Certificate] = None + self.certificate_chain: List[x509.Certificate] = [] + self.certificate_private_key: Optional[ + Union[dsa.DSAPrivateKey, ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey] + ] = None + self.handshake_extensions: List[Extension] = [] + self._max_early_data = max_early_data + self.session_ticket: Optional[SessionTicket] = None + self._server_name = server_name + if verify_mode is not None: + self._verify_mode = verify_mode + else: + self._verify_mode = ssl.CERT_REQUIRED if is_client else ssl.CERT_NONE + + # callbacks + self.alpn_cb: Optional[AlpnHandler] = None + 
self.get_session_ticket_cb: Optional[SessionTicketFetcher] = None + self.new_session_ticket_cb: Optional[SessionTicketHandler] = None + self.update_traffic_key_cb: Callable[ + [Direction, Epoch, CipherSuite, bytes], None + ] = lambda d, e, c, s: None + + # supported parameters + self._cipher_suites = [ + CipherSuite.AES_256_GCM_SHA384, + CipherSuite.AES_128_GCM_SHA256, + CipherSuite.CHACHA20_POLY1305_SHA256, + ] + self._compression_methods: List[int] = [CompressionMethod.NULL] + self._psk_key_exchange_modes: List[int] = [PskKeyExchangeMode.PSK_DHE_KE] + self._signature_algorithms: List[int] = [ + SignatureAlgorithm.RSA_PSS_RSAE_SHA256, + SignatureAlgorithm.ECDSA_SECP256R1_SHA256, + SignatureAlgorithm.RSA_PKCS1_SHA256, + SignatureAlgorithm.RSA_PKCS1_SHA1, + ] + self._supported_groups = [Group.SECP256R1] + if default_backend().x25519_supported(): + self._supported_groups.append(Group.X25519) + if default_backend().x448_supported(): + self._supported_groups.append(Group.X448) + self._supported_versions = [TLS_VERSION_1_3] + + # state + self.alpn_negotiated: Optional[str] = None + self.early_data_accepted = False + self.key_schedule: Optional[KeySchedule] = None + self.received_extensions: Optional[List[Extension]] = None + self._key_schedule_psk: Optional[KeySchedule] = None + self._key_schedule_proxy: Optional[KeyScheduleProxy] = None + self._new_session_ticket: Optional[NewSessionTicket] = None + self._peer_certificate: Optional[x509.Certificate] = None + self._peer_certificate_chain: List[x509.Certificate] = [] + self._receive_buffer = b"" + self._session_resumed = False + self._enc_key: Optional[bytes] = None + self._dec_key: Optional[bytes] = None + self.__logger = logger + + self._ec_private_key: Optional[ec.EllipticCurvePrivateKey] = None + self._x25519_private_key: Optional[x25519.X25519PrivateKey] = None + self._x448_private_key: Optional[x448.X448PrivateKey] = None + + if is_client: + self.client_random = os.urandom(32) + self.session_id = os.urandom(32) + 
self.state = State.CLIENT_HANDSHAKE_START + else: + self.client_random = None + self.session_id = None + self.state = State.SERVER_EXPECT_CLIENT_HELLO + + @property + def session_resumed(self) -> bool: + """ + Returns True if session resumption was successfully used. + """ + return self._session_resumed + + def handle_message( + self, input_data: bytes, output_buf: Dict[Epoch, Buffer] + ) -> None: + if self.state == State.CLIENT_HANDSHAKE_START: + self._client_send_hello(output_buf[Epoch.INITIAL]) + return + + self._receive_buffer += input_data + while len(self._receive_buffer) >= 4: + # determine message length + message_type = self._receive_buffer[0] + message_length = 0 + for b in self._receive_buffer[1:4]: + message_length = (message_length << 8) | b + message_length += 4 + + # check message is complete + if len(self._receive_buffer) < message_length: + break + message = self._receive_buffer[:message_length] + self._receive_buffer = self._receive_buffer[message_length:] + + input_buf = Buffer(data=message) + + # client states + + if self.state == State.CLIENT_EXPECT_SERVER_HELLO: + if message_type == HandshakeType.SERVER_HELLO: + self._client_handle_hello(input_buf, output_buf[Epoch.INITIAL]) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS: + if message_type == HandshakeType.ENCRYPTED_EXTENSIONS: + self._client_handle_encrypted_extensions(input_buf) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE: + if message_type == HandshakeType.CERTIFICATE: + self._client_handle_certificate(input_buf) + else: + # FIXME: handle certificate request + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_EXPECT_CERTIFICATE_VERIFY: + if message_type == HandshakeType.CERTIFICATE_VERIFY: + self._client_handle_certificate_verify(input_buf) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_EXPECT_FINISHED: + if message_type == 
HandshakeType.FINISHED: + self._client_handle_finished(input_buf, output_buf[Epoch.HANDSHAKE]) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_POST_HANDSHAKE: + if message_type == HandshakeType.NEW_SESSION_TICKET: + self._client_handle_new_session_ticket(input_buf) + else: + raise AlertUnexpectedMessage + + # server states + + elif self.state == State.SERVER_EXPECT_CLIENT_HELLO: + if message_type == HandshakeType.CLIENT_HELLO: + self._server_handle_hello( + input_buf, + output_buf[Epoch.INITIAL], + output_buf[Epoch.HANDSHAKE], + output_buf[Epoch.ONE_RTT], + ) + else: + raise AlertUnexpectedMessage + elif self.state == State.SERVER_EXPECT_FINISHED: + if message_type == HandshakeType.FINISHED: + self._server_handle_finished(input_buf, output_buf[Epoch.ONE_RTT]) + else: + raise AlertUnexpectedMessage + elif self.state == State.SERVER_POST_HANDSHAKE: + raise AlertUnexpectedMessage + + assert input_buf.eof() + + def _build_session_ticket( + self, new_session_ticket: NewSessionTicket + ) -> SessionTicket: + resumption_master_secret = self.key_schedule.derive_secret(b"res master") + resumption_secret = hkdf_expand_label( + algorithm=self.key_schedule.algorithm, + secret=resumption_master_secret, + label=b"resumption", + hash_value=new_session_ticket.ticket_nonce, + length=self.key_schedule.algorithm.digest_size, + ) + + timestamp = utcnow() + return SessionTicket( + age_add=new_session_ticket.ticket_age_add, + cipher_suite=self.key_schedule.cipher_suite, + max_early_data_size=new_session_ticket.max_early_data_size, + not_valid_after=timestamp + + datetime.timedelta(seconds=new_session_ticket.ticket_lifetime), + not_valid_before=timestamp, + other_extensions=self.handshake_extensions, + resumption_secret=resumption_secret, + server_name=self._server_name, + ticket=new_session_ticket.ticket, + ) + + def _client_send_hello(self, output_buf: Buffer) -> None: + key_share: List[KeyShareEntry] = [] + supported_groups: List[int] = [] + + for group in 
self._supported_groups: + if group == Group.SECP256R1: + self._ec_private_key = ec.generate_private_key( + GROUP_TO_CURVE[Group.SECP256R1](), default_backend() + ) + key_share.append(encode_public_key(self._ec_private_key.public_key())) + supported_groups.append(Group.SECP256R1) + elif group == Group.X25519: + self._x25519_private_key = x25519.X25519PrivateKey.generate() + key_share.append( + encode_public_key(self._x25519_private_key.public_key()) + ) + supported_groups.append(Group.X25519) + elif group == Group.X448: + self._x448_private_key = x448.X448PrivateKey.generate() + key_share.append(encode_public_key(self._x448_private_key.public_key())) + supported_groups.append(Group.X448) + elif group == Group.GREASE: + key_share.append((Group.GREASE, b"\x00")) + supported_groups.append(Group.GREASE) + + assert len(key_share), "no key share entries" + + hello = ClientHello( + random=self.client_random, + session_id=self.session_id, + cipher_suites=[int(x) for x in self._cipher_suites], + compression_methods=self._compression_methods, + alpn_protocols=self._alpn_protocols, + key_share=key_share, + psk_key_exchange_modes=self._psk_key_exchange_modes + if (self.session_ticket or self.new_session_ticket_cb is not None) + else None, + server_name=self._server_name, + signature_algorithms=self._signature_algorithms, + supported_groups=supported_groups, + supported_versions=self._supported_versions, + other_extensions=self.handshake_extensions, + ) + + # PSK + if self.session_ticket and self.session_ticket.is_valid: + self._key_schedule_psk = KeySchedule(self.session_ticket.cipher_suite) + self._key_schedule_psk.extract(self.session_ticket.resumption_secret) + binder_key = self._key_schedule_psk.derive_secret(b"res binder") + binder_length = self._key_schedule_psk.algorithm.digest_size + + # update hello + if self.session_ticket.max_early_data_size is not None: + hello.early_data = True + hello.pre_shared_key = OfferedPsks( + identities=[ + (self.session_ticket.ticket, 
self.session_ticket.obfuscated_age) + ], + binders=[bytes(binder_length)], + ) + + # serialize hello without binder + tmp_buf = Buffer(capacity=1024) + push_client_hello(tmp_buf, hello) + + # calculate binder + hash_offset = tmp_buf.tell() - binder_length - 3 + self._key_schedule_psk.update_hash(tmp_buf.data_slice(0, hash_offset)) + binder = self._key_schedule_psk.finished_verify_data(binder_key) + hello.pre_shared_key.binders[0] = binder + self._key_schedule_psk.update_hash( + tmp_buf.data_slice(hash_offset, hash_offset + 3) + binder + ) + + # calculate early data key + if hello.early_data: + early_key = self._key_schedule_psk.derive_secret(b"c e traffic") + self.update_traffic_key_cb( + Direction.ENCRYPT, + Epoch.ZERO_RTT, + self._key_schedule_psk.cipher_suite, + early_key, + ) + + self._key_schedule_proxy = KeyScheduleProxy(self._cipher_suites) + self._key_schedule_proxy.extract(None) + + with push_message(self._key_schedule_proxy, output_buf): + push_client_hello(output_buf, hello) + + self._set_state(State.CLIENT_EXPECT_SERVER_HELLO) + + def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None: + peer_hello = pull_server_hello(input_buf) + + cipher_suite = negotiate( + self._cipher_suites, + [peer_hello.cipher_suite], + AlertHandshakeFailure("Unsupported cipher suite"), + ) + assert peer_hello.compression_method in self._compression_methods + assert peer_hello.supported_version in self._supported_versions + + # select key schedule + if peer_hello.pre_shared_key is not None: + if ( + self._key_schedule_psk is None + or peer_hello.pre_shared_key != 0 + or cipher_suite != self._key_schedule_psk.cipher_suite + ): + raise AlertIllegalParameter + self.key_schedule = self._key_schedule_psk + self._session_resumed = True + else: + self.key_schedule = self._key_schedule_proxy.select(cipher_suite) + self._key_schedule_psk = None + self._key_schedule_proxy = None + + # perform key exchange + peer_public_key = decode_public_key(peer_hello.key_share) + 
shared_key: Optional[bytes] = None + if ( + isinstance(peer_public_key, x25519.X25519PublicKey) + and self._x25519_private_key is not None + ): + shared_key = self._x25519_private_key.exchange(peer_public_key) + elif ( + isinstance(peer_public_key, x448.X448PublicKey) + and self._x448_private_key is not None + ): + shared_key = self._x448_private_key.exchange(peer_public_key) + elif ( + isinstance(peer_public_key, ec.EllipticCurvePublicKey) + and self._ec_private_key is not None + and self._ec_private_key.public_key().curve.__class__ + == peer_public_key.curve.__class__ + ): + shared_key = self._ec_private_key.exchange(ec.ECDH(), peer_public_key) + assert shared_key is not None + + self.key_schedule.update_hash(input_buf.data) + self.key_schedule.extract(shared_key) + + self._setup_traffic_protection( + Direction.DECRYPT, Epoch.HANDSHAKE, b"s hs traffic" + ) + + self._set_state(State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS) + + def _client_handle_encrypted_extensions(self, input_buf: Buffer) -> None: + encrypted_extensions = pull_encrypted_extensions(input_buf) + + self.alpn_negotiated = encrypted_extensions.alpn_protocol + self.early_data_accepted = encrypted_extensions.early_data + self.received_extensions = encrypted_extensions.other_extensions + if self.alpn_cb: + self.alpn_cb(self.alpn_negotiated) + + self._setup_traffic_protection( + Direction.ENCRYPT, Epoch.HANDSHAKE, b"c hs traffic" + ) + self.key_schedule.update_hash(input_buf.data) + + # if the server accepted our PSK we are done, other we want its certificate + if self._session_resumed: + self._set_state(State.CLIENT_EXPECT_FINISHED) + else: + self._set_state(State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE) + + def _client_handle_certificate(self, input_buf: Buffer) -> None: + certificate = pull_certificate(input_buf) + + self._peer_certificate = x509.load_der_x509_certificate( + certificate.certificates[0][0], backend=default_backend() + ) + self._peer_certificate_chain = [ + 
x509.load_der_x509_certificate( + certificate.certificates[i][0], backend=default_backend() + ) + for i in range(1, len(certificate.certificates)) + ] + + self.key_schedule.update_hash(input_buf.data) + + self._set_state(State.CLIENT_EXPECT_CERTIFICATE_VERIFY) + + def _client_handle_certificate_verify(self, input_buf: Buffer) -> None: + verify = pull_certificate_verify(input_buf) + + assert verify.algorithm in self._signature_algorithms + + # check signature + try: + self._peer_certificate.public_key().verify( + verify.signature, + self.key_schedule.certificate_verify_data( + b"TLS 1.3, server CertificateVerify" + ), + *signature_algorithm_params(verify.algorithm), + ) + except InvalidSignature: + raise AlertDecryptError + + # check certificate + if self._verify_mode != ssl.CERT_NONE: + verify_certificate( + cadata=self._cadata, + cafile=self._cafile, + capath=self._capath, + certificate=self._peer_certificate, + chain=self._peer_certificate_chain, + server_name=self._server_name, + ) + + self.key_schedule.update_hash(input_buf.data) + + self._set_state(State.CLIENT_EXPECT_FINISHED) + + def _client_handle_finished(self, input_buf: Buffer, output_buf: Buffer) -> None: + finished = pull_finished(input_buf) + + # check verify data + expected_verify_data = self.key_schedule.finished_verify_data(self._dec_key) + if finished.verify_data != expected_verify_data: + raise AlertDecryptError + self.key_schedule.update_hash(input_buf.data) + + # prepare traffic keys + assert self.key_schedule.generation == 2 + self.key_schedule.extract(None) + self._setup_traffic_protection( + Direction.DECRYPT, Epoch.ONE_RTT, b"s ap traffic" + ) + next_enc_key = self.key_schedule.derive_secret(b"c ap traffic") + + # send finished + with push_message(self.key_schedule, output_buf): + push_finished( + output_buf, + Finished( + verify_data=self.key_schedule.finished_verify_data(self._enc_key) + ), + ) + + # commit traffic key + self._enc_key = next_enc_key + self.update_traffic_key_cb( + 
Direction.ENCRYPT, + Epoch.ONE_RTT, + self.key_schedule.cipher_suite, + self._enc_key, + ) + + self._set_state(State.CLIENT_POST_HANDSHAKE) + + def _client_handle_new_session_ticket(self, input_buf: Buffer) -> None: + new_session_ticket = pull_new_session_ticket(input_buf) + + # notify application + if self.new_session_ticket_cb is not None: + ticket = self._build_session_ticket(new_session_ticket) + self.new_session_ticket_cb(ticket) + + def _server_handle_hello( + self, + input_buf: Buffer, + initial_buf: Buffer, + handshake_buf: Buffer, + onertt_buf: Buffer, + ) -> None: + peer_hello = pull_client_hello(input_buf) + + # determine applicable signature algorithms + signature_algorithms: List[SignatureAlgorithm] = [] + if isinstance(self.certificate_private_key, rsa.RSAPrivateKey): + signature_algorithms = [ + SignatureAlgorithm.RSA_PSS_RSAE_SHA256, + SignatureAlgorithm.RSA_PKCS1_SHA256, + SignatureAlgorithm.RSA_PKCS1_SHA1, + ] + elif isinstance( + self.certificate_private_key, ec.EllipticCurvePrivateKey + ) and isinstance(self.certificate_private_key.curve, ec.SECP256R1): + signature_algorithms = [SignatureAlgorithm.ECDSA_SECP256R1_SHA256] + + # negotiate parameters + cipher_suite = negotiate( + self._cipher_suites, + peer_hello.cipher_suites, + AlertHandshakeFailure("No supported cipher suite"), + ) + compression_method = negotiate( + self._compression_methods, + peer_hello.compression_methods, + AlertHandshakeFailure("No supported compression method"), + ) + psk_key_exchange_mode = negotiate( + self._psk_key_exchange_modes, peer_hello.psk_key_exchange_modes + ) + signature_algorithm = negotiate( + signature_algorithms, + peer_hello.signature_algorithms, + AlertHandshakeFailure("No supported signature algorithm"), + ) + supported_version = negotiate( + self._supported_versions, + peer_hello.supported_versions, + AlertProtocolVersion("No supported protocol version"), + ) + + # negotiate ALPN + if self._alpn_protocols is not None: + self.alpn_negotiated = 
negotiate( + self._alpn_protocols, + peer_hello.alpn_protocols, + AlertHandshakeFailure("No common ALPN protocols"), + ) + if self.alpn_cb: + self.alpn_cb(self.alpn_negotiated) + + self.client_random = peer_hello.random + self.server_random = os.urandom(32) + self.session_id = peer_hello.session_id + self.received_extensions = peer_hello.other_extensions + + # select key schedule + pre_shared_key = None + if ( + self.get_session_ticket_cb is not None + and psk_key_exchange_mode is not None + and peer_hello.pre_shared_key is not None + and len(peer_hello.pre_shared_key.identities) == 1 + and len(peer_hello.pre_shared_key.binders) == 1 + ): + # ask application to find session ticket + identity = peer_hello.pre_shared_key.identities[0] + session_ticket = self.get_session_ticket_cb(identity[0]) + + # validate session ticket + if ( + session_ticket is not None + and session_ticket.is_valid + and session_ticket.cipher_suite == cipher_suite + ): + self.key_schedule = KeySchedule(cipher_suite) + self.key_schedule.extract(session_ticket.resumption_secret) + + binder_key = self.key_schedule.derive_secret(b"res binder") + binder_length = self.key_schedule.algorithm.digest_size + + hash_offset = input_buf.tell() - binder_length - 3 + binder = input_buf.data_slice( + hash_offset + 3, hash_offset + 3 + binder_length + ) + + self.key_schedule.update_hash(input_buf.data_slice(0, hash_offset)) + expected_binder = self.key_schedule.finished_verify_data(binder_key) + + if binder != expected_binder: + raise AlertHandshakeFailure("PSK validation failed") + + self.key_schedule.update_hash( + input_buf.data_slice(hash_offset, hash_offset + 3 + binder_length) + ) + self._session_resumed = True + + # calculate early data key + if peer_hello.early_data: + early_key = self.key_schedule.derive_secret(b"c e traffic") + self.early_data_accepted = True + self.update_traffic_key_cb( + Direction.DECRYPT, + Epoch.ZERO_RTT, + self.key_schedule.cipher_suite, + early_key, + ) + + pre_shared_key = 0 + 
+ # if PSK is not used, initialize key schedule + if pre_shared_key is None: + self.key_schedule = KeySchedule(cipher_suite) + self.key_schedule.extract(None) + self.key_schedule.update_hash(input_buf.data) + + # perform key exchange + public_key: Union[ + ec.EllipticCurvePublicKey, x25519.X25519PublicKey, x448.X448PublicKey + ] + shared_key: Optional[bytes] = None + for key_share in peer_hello.key_share: + peer_public_key = decode_public_key(key_share) + if isinstance(peer_public_key, x25519.X25519PublicKey): + self._x25519_private_key = x25519.X25519PrivateKey.generate() + public_key = self._x25519_private_key.public_key() + shared_key = self._x25519_private_key.exchange(peer_public_key) + break + elif isinstance(peer_public_key, x448.X448PublicKey): + self._x448_private_key = x448.X448PrivateKey.generate() + public_key = self._x448_private_key.public_key() + shared_key = self._x448_private_key.exchange(peer_public_key) + break + elif isinstance(peer_public_key, ec.EllipticCurvePublicKey): + self._ec_private_key = ec.generate_private_key( + GROUP_TO_CURVE[key_share[0]](), default_backend() + ) + public_key = self._ec_private_key.public_key() + shared_key = self._ec_private_key.exchange(ec.ECDH(), peer_public_key) + break + assert shared_key is not None + + # send hello + hello = ServerHello( + random=self.server_random, + session_id=self.session_id, + cipher_suite=cipher_suite, + compression_method=compression_method, + key_share=encode_public_key(public_key), + pre_shared_key=pre_shared_key, + supported_version=supported_version, + ) + with push_message(self.key_schedule, initial_buf): + push_server_hello(initial_buf, hello) + self.key_schedule.extract(shared_key) + + self._setup_traffic_protection( + Direction.ENCRYPT, Epoch.HANDSHAKE, b"s hs traffic" + ) + self._setup_traffic_protection( + Direction.DECRYPT, Epoch.HANDSHAKE, b"c hs traffic" + ) + + # send encrypted extensions + with push_message(self.key_schedule, handshake_buf): + push_encrypted_extensions( + 
handshake_buf, + EncryptedExtensions( + alpn_protocol=self.alpn_negotiated, + early_data=self.early_data_accepted, + other_extensions=self.handshake_extensions, + ), + ) + + if pre_shared_key is None: + # send certificate + with push_message(self.key_schedule, handshake_buf): + push_certificate( + handshake_buf, + Certificate( + request_context=b"", + certificates=[ + (x.public_bytes(Encoding.DER), b"") + for x in [self.certificate] + self.certificate_chain + ], + ), + ) + + # send certificate verify + signature = self.certificate_private_key.sign( + self.key_schedule.certificate_verify_data( + b"TLS 1.3, server CertificateVerify" + ), + *signature_algorithm_params(signature_algorithm), + ) + with push_message(self.key_schedule, handshake_buf): + push_certificate_verify( + handshake_buf, + CertificateVerify( + algorithm=signature_algorithm, signature=signature + ), + ) + + # send finished + with push_message(self.key_schedule, handshake_buf): + push_finished( + handshake_buf, + Finished( + verify_data=self.key_schedule.finished_verify_data(self._enc_key) + ), + ) + + # prepare traffic keys + assert self.key_schedule.generation == 2 + self.key_schedule.extract(None) + self._setup_traffic_protection( + Direction.ENCRYPT, Epoch.ONE_RTT, b"s ap traffic" + ) + self._next_dec_key = self.key_schedule.derive_secret(b"c ap traffic") + + # anticipate client's FINISHED as we don't use client auth + self._expected_verify_data = self.key_schedule.finished_verify_data( + self._dec_key + ) + buf = Buffer(capacity=64) + push_finished(buf, Finished(verify_data=self._expected_verify_data)) + self.key_schedule.update_hash(buf.data) + + # create a new session ticket + if self.new_session_ticket_cb is not None and psk_key_exchange_mode is not None: + self._new_session_ticket = NewSessionTicket( + ticket_lifetime=86400, + ticket_age_add=struct.unpack("I", os.urandom(4))[0], + ticket_nonce=b"", + ticket=os.urandom(64), + max_early_data_size=self._max_early_data, + ) + + # send messsage + 
push_new_session_ticket(onertt_buf, self._new_session_ticket) + + # notify application + ticket = self._build_session_ticket(self._new_session_ticket) + self.new_session_ticket_cb(ticket) + + self._set_state(State.SERVER_EXPECT_FINISHED) + + def _server_handle_finished(self, input_buf: Buffer, output_buf: Buffer) -> None: + finished = pull_finished(input_buf) + + # check verify data + if finished.verify_data != self._expected_verify_data: + raise AlertDecryptError + + # commit traffic key + self._dec_key = self._next_dec_key + self._next_dec_key = None + self.update_traffic_key_cb( + Direction.DECRYPT, + Epoch.ONE_RTT, + self.key_schedule.cipher_suite, + self._dec_key, + ) + + self._set_state(State.SERVER_POST_HANDSHAKE) + + def _setup_traffic_protection( + self, direction: Direction, epoch: Epoch, label: bytes + ) -> None: + key = self.key_schedule.derive_secret(label) + + if direction == Direction.ENCRYPT: + self._enc_key = key + else: + self._dec_key = key + + self.update_traffic_key_cb( + direction, epoch, self.key_schedule.cipher_suite, key + ) + + def _set_state(self, state: State) -> None: + if self.__logger: + self.__logger.debug("TLS %s -> %s", self.state, state) + self.state = state diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/__init__.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/initial_client.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/initial_client.bin new file mode 100644 index 000000000000..9d7b7055cdaa Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/initial_client.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/initial_server.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/initial_server.bin new file mode 100644 index 000000000000..75bbb9c0aff8 Binary files 
/dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/initial_server.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/pycacert.pem b/testing/web-platform/tests/tools/third_party/aioquic/tests/pycacert.pem new file mode 100644 index 000000000000..73150c960f35 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/pycacert.pem @@ -0,0 +1,99 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + cb:2d:80:99:5a:69:52:5b + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Aug 29 14:23:16 2018 GMT + Not After : Aug 26 14:23:16 2028 GMT + Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (3072 bit) + Modulus: + 00:97:ed:55:41:ba:36:17:95:db:71:1c:d3:e1:61: + ac:58:73:e3:c6:96:cf:2b:1f:b8:08:f5:9d:4b:4b: + c7:30:f6:b8:0b:b3:52:72:a0:bb:c9:4d:3b:8e:df: + 22:8e:01:57:81:c9:92:73:cc:00:c6:ec:70:b0:3a: + 17:40:c1:df:f2:8c:36:4c:c4:a7:81:e7:b6:24:68: + e2:a0:7e:35:07:2f:a0:5b:f9:45:46:f7:1e:f0:46: + 11:fe:ca:1a:3c:50:f1:26:a9:5f:9c:22:9c:f8:41: + e1:df:4f:12:95:19:2f:5c:90:01:17:6e:7e:3e:7d: + cf:e9:09:af:25:f8:f8:42:77:2d:6d:5f:36:f2:78: + 1e:7d:4a:87:68:63:6c:06:71:1b:8d:fa:25:fe:d4: + d3:f5:a5:17:b1:ef:ea:17:cb:54:c8:27:99:80:cb: + 3c:45:f1:2c:52:1c:dd:1f:51:45:20:50:1e:5e:ab: + 57:73:1b:41:78:96:de:84:a4:7a:dd:8f:30:85:36: + 58:79:76:a0:d2:61:c8:1b:a9:94:99:63:c6:ee:f8: + 14:bf:b4:52:56:31:97:fa:eb:ac:53:9e:95:ce:4c: + c4:5a:4a:b7:ca:03:27:5b:35:57:ce:02:dc:ec:ca: + 69:f8:8a:5a:39:cb:16:20:15:03:24:61:6c:f4:7a: + fc:b6:48:e5:59:10:5c:49:d0:23:9f:fb:71:5e:3a: + e9:68:9f:34:72:80:27:b6:3f:4c:b1:d9:db:63:7f: + 67:68:4a:6e:11:f8:e8:c0:f4:5a:16:39:53:0b:68: + de:77:fa:45:e7:f8:91:cd:78:cd:28:94:97:71:54: + fb:cf:f0:37:de:c9:26:c5:dc:1b:9e:89:6d:09:ac: + c8:44:71:cb:6d:f1:97:31:d5:4c:20:33:bf:75:4a: + 
a0:e0:dc:69:11:ed:2a:b4:64:10:11:30:8b:0e:b0: + a7:10:d8:8a:c5:aa:1b:c8:26:8a:25:e7:66:9f:a5: + 6a:1a:2f:7c:5f:83:c6:78:4f:1f + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + DD:BF:CA:DA:E6:D1:34:BA:37:75:21:CA:6F:9A:08:28:F2:35:B6:48 + X509v3 Authority Key Identifier: + keyid:DD:BF:CA:DA:E6:D1:34:BA:37:75:21:CA:6F:9A:08:28:F2:35:B6:48 + + X509v3 Basic Constraints: + CA:TRUE + Signature Algorithm: sha256WithRSAEncryption + 33:6a:54:d3:6b:c0:d7:01:5f:9d:f4:05:c1:93:66:90:50:d0: + b7:18:e9:b0:1e:4a:a0:b6:da:76:93:af:84:db:ad:15:54:31: + 15:13:e4:de:7e:4e:0c:d5:09:1c:34:35:b6:e5:4c:d6:6f:65: + 7d:32:5f:eb:fc:a9:6b:07:f7:49:82:e5:81:7e:07:80:9a:63: + f8:2c:c3:40:bc:8f:d4:2a:da:3e:d1:ee:08:b7:4d:a7:84:ca: + f4:3f:a1:98:45:be:b1:05:69:e7:df:d7:99:ab:1b:ee:8b:30: + cc:f7:fc:e7:d4:0b:17:ae:97:bf:e4:7b:fd:0f:a7:b4:85:79: + e3:59:e2:16:87:bf:1f:29:45:2c:23:93:76:be:c0:87:1d:de: + ec:2b:42:6a:e5:bb:c8:f4:0a:4a:08:0a:8c:5c:d8:7d:4d:d1: + b8:bf:d5:f7:29:ed:92:d1:94:04:e8:35:06:57:7f:2c:23:97: + 87:a5:35:8d:26:d3:1a:47:f2:16:d7:d9:c6:d4:1f:23:43:d3: + 26:99:39:ca:20:f4:71:23:6f:0c:4a:76:76:f7:76:1f:b3:fe: + bf:47:b0:fc:2a:56:81:e1:d2:dd:ee:08:d8:f4:ff:5a:dc:25: + 61:8a:91:02:b9:86:1c:f2:50:73:76:25:35:fc:b6:25:26:15: + cb:eb:c4:2b:61:0c:1c:e7:ee:2f:17:9b:ec:f0:d4:a1:84:e7: + d2:af:de:e4:1b:24:14:a7:01:87:e3:ab:29:58:46:a0:d9:c0: + 0a:e0:8d:d7:59:d3:1b:f8:54:20:3e:78:a5:a5:c8:4f:8b:03: + c4:96:9f:ec:fb:47:cf:76:2d:8d:65:34:27:bf:fa:ae:01:05: + 8a:f3:92:0a:dd:89:6c:97:a1:c7:e7:60:51:e7:ac:eb:4b:7d: + 2c:b8:65:c9:fe:5d:6a:48:55:8e:e4:c7:f9:6a:40:e1:b8:64: + 45:e9:b5:59:29:a5:5f:cf:7d:58:7d:64:79:e5:a4:09:ac:1e: + 76:65:3d:94:c4:68 +-----BEGIN CERTIFICATE----- +MIIEbTCCAtWgAwIBAgIJAMstgJlaaVJbMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV +BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW +MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA4MjYx +NDIzMTZaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg 
+Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCAaIwDQYJKoZI +hvcNAQEBBQADggGPADCCAYoCggGBAJftVUG6NheV23Ec0+FhrFhz48aWzysfuAj1 +nUtLxzD2uAuzUnKgu8lNO47fIo4BV4HJknPMAMbscLA6F0DB3/KMNkzEp4HntiRo +4qB+NQcvoFv5RUb3HvBGEf7KGjxQ8SapX5winPhB4d9PEpUZL1yQARdufj59z+kJ +ryX4+EJ3LW1fNvJ4Hn1Kh2hjbAZxG436Jf7U0/WlF7Hv6hfLVMgnmYDLPEXxLFIc +3R9RRSBQHl6rV3MbQXiW3oSket2PMIU2WHl2oNJhyBuplJljxu74FL+0UlYxl/rr +rFOelc5MxFpKt8oDJ1s1V84C3OzKafiKWjnLFiAVAyRhbPR6/LZI5VkQXEnQI5/7 +cV466WifNHKAJ7Y/TLHZ22N/Z2hKbhH46MD0WhY5Uwto3nf6Ref4kc14zSiUl3FU ++8/wN97JJsXcG56JbQmsyERxy23xlzHVTCAzv3VKoODcaRHtKrRkEBEwiw6wpxDY +isWqG8gmiiXnZp+lahovfF+DxnhPHwIDAQABo1AwTjAdBgNVHQ4EFgQU3b/K2ubR +NLo3dSHKb5oIKPI1tkgwHwYDVR0jBBgwFoAU3b/K2ubRNLo3dSHKb5oIKPI1tkgw +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAYEAM2pU02vA1wFfnfQFwZNm +kFDQtxjpsB5KoLbadpOvhNutFVQxFRPk3n5ODNUJHDQ1tuVM1m9lfTJf6/ypawf3 +SYLlgX4HgJpj+CzDQLyP1CraPtHuCLdNp4TK9D+hmEW+sQVp59/Xmasb7oswzPf8 +59QLF66Xv+R7/Q+ntIV541niFoe/HylFLCOTdr7Ahx3e7CtCauW7yPQKSggKjFzY +fU3RuL/V9yntktGUBOg1Bld/LCOXh6U1jSbTGkfyFtfZxtQfI0PTJpk5yiD0cSNv +DEp2dvd2H7P+v0ew/CpWgeHS3e4I2PT/WtwlYYqRArmGHPJQc3YlNfy2JSYVy+vE +K2EMHOfuLxeb7PDUoYTn0q/e5BskFKcBh+OrKVhGoNnACuCN11nTG/hUID54paXI +T4sDxJaf7PtHz3YtjWU0J7/6rgEFivOSCt2JbJehx+dgUees60t9LLhlyf5dakhV +juTH+WpA4bhkRem1WSmlX899WH1keeWkCawedmU9lMRo +-----END CERTIFICATE----- diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/retry.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/retry.bin new file mode 100644 index 000000000000..0f3bd4981a4a Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/retry.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/short_header.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/short_header.bin new file mode 100644 index 000000000000..283683d2d5b5 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/short_header.bin @@ -0,0 +1 @@ +]ôZ§µœÖæhõ0LÔý³y“' \ No 
newline at end of file diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/ssl_cert.pem b/testing/web-platform/tests/tools/third_party/aioquic/tests/ssl_cert.pem new file mode 100644 index 000000000000..1f7127a4b025 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/ssl_cert.pem @@ -0,0 +1,34 @@ +-----BEGIN CERTIFICATE----- +MIIF8TCCBFmgAwIBAgIJAMstgJlaaVJcMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV +BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW +MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA3MDcx +NDIzMTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQHDA5DYXN0bGUgQW50aHJheDEj +MCEGA1UECgwaUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMMCWxv +Y2FsaG9zdDCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoCggGBAJ8oLzdB739k +YxZiFukBFGIpyjqYkj0I015p/sDz1MT7DljcZLBLy7OqnkLpB5tnM8256DwdihPA +3zlnfEzTfr9DD0qFBW2H5cMCoz7X17koeRhzGDd3dkjUeBjXvR5qRosG8wM3lQug +U7AizY+3Azaj1yN3mZ9K5a20jr58Kqinz+Xxx6sb2JfYYff2neJbBahNm5id0AD2 +pi/TthZqO5DURJYo+MdgZOcy+7jEjOJsLWZd3Yzq78iM07qDjbpIoVpENZCTHTWA +hX8LIqz0OBmh4weQpm4+plU7E4r4D82uauocWw8iyuznCTtABWO7n9fWySmf9QZC +WYxHAFpBQs6zUVqAD7nhFdTqpQ9bRiaEnjE4HiAccPW+MAoSxFnv/rNzEzI6b4zU +NspFMfg1aNVamdjxdpUZ1GG1Okf0yPJykqEX4PZl3La1Be2q7YZ1wydR523Xd+f3 +EO4/g+imETSKn8gyCf6Rvib175L4r2WV1CXQH7gFwZYCod6WHYq5TQIDAQABo4IB +wDCCAbwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA4GA1UdDwEB/wQEAwIFoDAdBgNV +HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E +FgQUj+od4zNcABazi29rb9NMy7XLfFUwfQYDVR0jBHYwdIAU3b/K2ubRNLo3dSHK +b5oIKPI1tkihUaRPME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29m +dHdhcmUgRm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcoIJAMst +gJlaaVJbMIGDBggrBgEFBQcBAQR3MHUwPAYIKwYBBQUHMAKGMGh0dHA6Ly90ZXN0 +Y2EucHl0aG9udGVzdC5uZXQvdGVzdGNhL3B5Y2FjZXJ0LmNlcjA1BggrBgEFBQcw +AYYpaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0Y2Evb2NzcC8wQwYD +VR0fBDwwOjA4oDagNIYyaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0 +Y2EvcmV2b2NhdGlvbi5jcmwwDQYJKoZIhvcNAQELBQADggGBACf1jFkQ9MbnKAC/ 
+uo17EwPxHKZfswZVpCK527LVRr33DN1DbrR5ZWchDCpV7kCOhZ+fR7sKKk22ZHSY +oH+u3PEu20J3GOB1iyY1aMNB7WvId3JvappdVWkC/VpUyFfLsGUDFuIPADmZZqCb +iJMX4loteTVfl1d4xK/1mV6Gq9MRrRqiDfpSELn+v53OM9mGspwW+NZ1CIrbCuW0 +KxZ/tPkqn8PSd9fNZR70bB7rWbnwrl+kH8xKxLl6qdlrMmg74WWwhLeQxK7+9DdP +IaDenzqx5cwWBGY/C0HcQj0gPuy3lSs1V/q+f7Y6uspPWP51PgiJLIywXS75iRAr ++UFGTzwAtyfTZSQoFyMmMULqfk6T5HtoVMqfRvPvK+mFDLWEstU1NIB1K/CRI7gI +AY65ClTU+zRS/tlF8IA7tsFvgtEf8jsI9kamlidhS1gyeg4dWcVErV4aeTPB1AUv +StPYQkKNM+NjytWHl5tNuBoDNLsc0gI/WSPiI4CIY8LwomOoiw== +-----END CERTIFICATE----- diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/ssl_cert_with_chain.pem b/testing/web-platform/tests/tools/third_party/aioquic/tests/ssl_cert_with_chain.pem new file mode 100644 index 000000000000..b60ad41687e2 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/ssl_cert_with_chain.pem @@ -0,0 +1,60 @@ +-----BEGIN CERTIFICATE----- +MIIF8TCCBFmgAwIBAgIJAMstgJlaaVJcMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV +BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW +MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA3MDcx +NDIzMTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQHDA5DYXN0bGUgQW50aHJheDEj +MCEGA1UECgwaUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMMCWxv +Y2FsaG9zdDCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoCggGBAJ8oLzdB739k +YxZiFukBFGIpyjqYkj0I015p/sDz1MT7DljcZLBLy7OqnkLpB5tnM8256DwdihPA +3zlnfEzTfr9DD0qFBW2H5cMCoz7X17koeRhzGDd3dkjUeBjXvR5qRosG8wM3lQug +U7AizY+3Azaj1yN3mZ9K5a20jr58Kqinz+Xxx6sb2JfYYff2neJbBahNm5id0AD2 +pi/TthZqO5DURJYo+MdgZOcy+7jEjOJsLWZd3Yzq78iM07qDjbpIoVpENZCTHTWA +hX8LIqz0OBmh4weQpm4+plU7E4r4D82uauocWw8iyuznCTtABWO7n9fWySmf9QZC +WYxHAFpBQs6zUVqAD7nhFdTqpQ9bRiaEnjE4HiAccPW+MAoSxFnv/rNzEzI6b4zU +NspFMfg1aNVamdjxdpUZ1GG1Okf0yPJykqEX4PZl3La1Be2q7YZ1wydR523Xd+f3 +EO4/g+imETSKn8gyCf6Rvib175L4r2WV1CXQH7gFwZYCod6WHYq5TQIDAQABo4IB +wDCCAbwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA4GA1UdDwEB/wQEAwIFoDAdBgNV +HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E 
+FgQUj+od4zNcABazi29rb9NMy7XLfFUwfQYDVR0jBHYwdIAU3b/K2ubRNLo3dSHK +b5oIKPI1tkihUaRPME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29m +dHdhcmUgRm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcoIJAMst +gJlaaVJbMIGDBggrBgEFBQcBAQR3MHUwPAYIKwYBBQUHMAKGMGh0dHA6Ly90ZXN0 +Y2EucHl0aG9udGVzdC5uZXQvdGVzdGNhL3B5Y2FjZXJ0LmNlcjA1BggrBgEFBQcw +AYYpaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0Y2Evb2NzcC8wQwYD +VR0fBDwwOjA4oDagNIYyaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0 +Y2EvcmV2b2NhdGlvbi5jcmwwDQYJKoZIhvcNAQELBQADggGBACf1jFkQ9MbnKAC/ +uo17EwPxHKZfswZVpCK527LVRr33DN1DbrR5ZWchDCpV7kCOhZ+fR7sKKk22ZHSY +oH+u3PEu20J3GOB1iyY1aMNB7WvId3JvappdVWkC/VpUyFfLsGUDFuIPADmZZqCb +iJMX4loteTVfl1d4xK/1mV6Gq9MRrRqiDfpSELn+v53OM9mGspwW+NZ1CIrbCuW0 +KxZ/tPkqn8PSd9fNZR70bB7rWbnwrl+kH8xKxLl6qdlrMmg74WWwhLeQxK7+9DdP +IaDenzqx5cwWBGY/C0HcQj0gPuy3lSs1V/q+f7Y6uspPWP51PgiJLIywXS75iRAr ++UFGTzwAtyfTZSQoFyMmMULqfk6T5HtoVMqfRvPvK+mFDLWEstU1NIB1K/CRI7gI +AY65ClTU+zRS/tlF8IA7tsFvgtEf8jsI9kamlidhS1gyeg4dWcVErV4aeTPB1AUv +StPYQkKNM+NjytWHl5tNuBoDNLsc0gI/WSPiI4CIY8LwomOoiw== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIEbTCCAtWgAwIBAgIJAMstgJlaaVJbMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV +BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW +MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA4MjYx +NDIzMTZaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg +Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCAaIwDQYJKoZI +hvcNAQEBBQADggGPADCCAYoCggGBAJftVUG6NheV23Ec0+FhrFhz48aWzysfuAj1 +nUtLxzD2uAuzUnKgu8lNO47fIo4BV4HJknPMAMbscLA6F0DB3/KMNkzEp4HntiRo +4qB+NQcvoFv5RUb3HvBGEf7KGjxQ8SapX5winPhB4d9PEpUZL1yQARdufj59z+kJ +ryX4+EJ3LW1fNvJ4Hn1Kh2hjbAZxG436Jf7U0/WlF7Hv6hfLVMgnmYDLPEXxLFIc +3R9RRSBQHl6rV3MbQXiW3oSket2PMIU2WHl2oNJhyBuplJljxu74FL+0UlYxl/rr +rFOelc5MxFpKt8oDJ1s1V84C3OzKafiKWjnLFiAVAyRhbPR6/LZI5VkQXEnQI5/7 +cV466WifNHKAJ7Y/TLHZ22N/Z2hKbhH46MD0WhY5Uwto3nf6Ref4kc14zSiUl3FU ++8/wN97JJsXcG56JbQmsyERxy23xlzHVTCAzv3VKoODcaRHtKrRkEBEwiw6wpxDY 
+isWqG8gmiiXnZp+lahovfF+DxnhPHwIDAQABo1AwTjAdBgNVHQ4EFgQU3b/K2ubR +NLo3dSHKb5oIKPI1tkgwHwYDVR0jBBgwFoAU3b/K2ubRNLo3dSHKb5oIKPI1tkgw +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAYEAM2pU02vA1wFfnfQFwZNm +kFDQtxjpsB5KoLbadpOvhNutFVQxFRPk3n5ODNUJHDQ1tuVM1m9lfTJf6/ypawf3 +SYLlgX4HgJpj+CzDQLyP1CraPtHuCLdNp4TK9D+hmEW+sQVp59/Xmasb7oswzPf8 +59QLF66Xv+R7/Q+ntIV541niFoe/HylFLCOTdr7Ahx3e7CtCauW7yPQKSggKjFzY +fU3RuL/V9yntktGUBOg1Bld/LCOXh6U1jSbTGkfyFtfZxtQfI0PTJpk5yiD0cSNv +DEp2dvd2H7P+v0ew/CpWgeHS3e4I2PT/WtwlYYqRArmGHPJQc3YlNfy2JSYVy+vE +K2EMHOfuLxeb7PDUoYTn0q/e5BskFKcBh+OrKVhGoNnACuCN11nTG/hUID54paXI +T4sDxJaf7PtHz3YtjWU0J7/6rgEFivOSCt2JbJehx+dgUees60t9LLhlyf5dakhV +juTH+WpA4bhkRem1WSmlX899WH1keeWkCawedmU9lMRo +-----END CERTIFICATE----- diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/ssl_key.pem b/testing/web-platform/tests/tools/third_party/aioquic/tests/ssl_key.pem new file mode 100644 index 000000000000..b8fdec101c8d --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/ssl_key.pem @@ -0,0 +1,40 @@ +-----BEGIN PRIVATE KEY----- +MIIG/QIBADANBgkqhkiG9w0BAQEFAASCBucwggbjAgEAAoIBgQCfKC83Qe9/ZGMW +YhbpARRiKco6mJI9CNNeaf7A89TE+w5Y3GSwS8uzqp5C6QebZzPNueg8HYoTwN85 +Z3xM036/Qw9KhQVth+XDAqM+19e5KHkYcxg3d3ZI1HgY170eakaLBvMDN5ULoFOw +Is2PtwM2o9cjd5mfSuWttI6+fCqop8/l8cerG9iX2GH39p3iWwWoTZuYndAA9qYv +07YWajuQ1ESWKPjHYGTnMvu4xIzibC1mXd2M6u/IjNO6g426SKFaRDWQkx01gIV/ +CyKs9DgZoeMHkKZuPqZVOxOK+A/NrmrqHFsPIsrs5wk7QAVju5/X1skpn/UGQlmM +RwBaQULOs1FagA+54RXU6qUPW0YmhJ4xOB4gHHD1vjAKEsRZ7/6zcxMyOm+M1DbK +RTH4NWjVWpnY8XaVGdRhtTpH9MjycpKhF+D2Zdy2tQXtqu2GdcMnUedt13fn9xDu +P4PophE0ip/IMgn+kb4m9e+S+K9lldQl0B+4BcGWAqHelh2KuU0CAwEAAQKCAYEA +lKiWIYjmyRjdLKUGPTES9vWNvNmRjozV0RQ0LcoSbMMLDZkeO0UwyWqOVHUQ8+ib +jIcfEjeNJxI57oZopeHOO5vJhpNlFH+g7ltiW2qERqA1K88lSXm99Bzw6FNqhCRE +K8ub5N9fyfJA+P4o/xm0WK8EXk5yIUV17p/9zJJxzgKgv2jsVTi3QG2OZGvn4Oug +ByomMZEGHkBDzdxz8c/cP1Tlk1RFuwSgews178k2xq7AYSM/s0YmHi7b/RSvptX6 
+1v8P8kXNUe4AwTaNyrlvF2lwIadZ8h1hA7tCE2n44b7a7KfhAkwcbr1T59ioYh6P +zxsyPT678uD51dbtD/DXJCcoeeFOb8uzkR2KNcrnQzZpCJnRq4Gp5ybxwsxxuzpr +gz0gbNlhuWtE7EoSzmIK9t+WTS7IM2CvZymd6/OAh1Fuw6AQhSp64XRp3OfMMAAC +Ie2EPtKj4islWGT8VoUjuRYGmdRh4duAH1dkiAXOWA3R7y5a1/y/iE8KE8BtxocB +AoHBAM8aiURgpu1Fs0Oqz6izec7KSLL3l8hmW+MKUOfk/Ybng6FrTFsL5YtzR+Ap +wW4wwWnnIKEc1JLiZ7g8agRETK8hr5PwFXUn/GSWC0SMsazLJToySQS5LOV0tLzK +kJ3jtNU7tnlDGNkCHTHSoVL2T/8t+IkZI/h5Z6wjlYPvU2Iu0nVIXtiG+alv4A6M +Hrh9l5or4mjB6rGnVXeYohLkCm6s/W97ahVxLMcEdbsBo1prm2JqGnSoiR/tEFC/ +QHQnbQKBwQDEu7kW0Yg9sZ89QtYtVQ1YpixFZORaUeRIRLnpEs1w7L1mCbOZ2Lj9 +JHxsH05cYAc7HJfPwwxv3+3aGAIC/dfu4VSwEFtatAzUpzlhzKS5+HQCWB4JUNNU +MQ3+FwK2xQX4Ph8t+OzrFiYcK2g0An5UxWMa2HWIAWUOhnTOydAVsoH6yP31cVm4 +0hxoABCwflaNLNGjRUyfBpLTAcNu/YtcE+KREy7YAAgXXrhRSO4XpLsSXwLnLT7/ +YOkoBWDcTWECgcBPWnSUDZCIQ3efithMZJBciqd2Y2X19Dpq8O31HImD4jtOY0V7 +cUB/wSkeHAGwjd/eCyA2e0x8B2IEdqmMfvr+86JJxekC3dJYXCFvH5WIhsH53YCa +3bT1KlWCLP9ib/g+58VQC0R/Cc9T4sfLePNH7D5ZkZd1wlbV30CPr+i8KwKay6MD +xhvtLx+jk07GE+E9wmjbCMo7TclyrLoVEOlqZMAqshgApT+p9eyCPetwXuDHwa3n +WxhHclcZCV7R4rUCgcAkdGSnxcvpIrDPOUNWwxvmAWTStw9ZbTNP8OxCNCm9cyDl +d4bAS1h8D/a+Uk7C70hnu7Sl2w7C7Eu2zhwRUdhhe3+l4GINPK/j99i6NqGPlGpq +xMlMEJ4YS768BqeKFpg0l85PRoEgTsphDeoROSUPsEPdBZ9BxIBlYKTkbKESZDGR +twzYHljx1n1NCDYPflmrb1KpXn4EOcObNghw2KqqNUUWfOeBPwBA1FxzM4BrAStp +DBINpGS4Dc0mjViVegECgcA3hTtm82XdxQXj9LQmb/E3lKx/7H87XIOeNMmvjYuZ +iS9wKrkF+u42vyoDxcKMCnxP5056wpdST4p56r+SBwVTHcc3lGBSGcMTIfwRXrj3 +thOA2our2n4ouNIsYyTlcsQSzifwmpRmVMRPxl9fYVdEWUgB83FgHT0D9avvZnF9 +t9OccnGJXShAIZIBADhVj/JwG4FbaX42NijD5PNpVLk1Y17OV0I576T9SfaQoBjJ +aH1M/zC4aVaS0DYB/Gxq7v8= +-----END PRIVATE KEY----- diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_asyncio.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_asyncio.py new file mode 100644 index 000000000000..81fcbf3803dc --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_asyncio.py @@ -0,0 +1,374 @@ +import asyncio +import 
binascii +import random +import socket +from unittest import TestCase, skipIf +from unittest.mock import patch + +from cryptography.hazmat.primitives import serialization + +from aioquic.asyncio.client import connect +from aioquic.asyncio.protocol import QuicConnectionProtocol +from aioquic.asyncio.server import serve +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.logger import QuicLogger + +from .utils import ( + SERVER_CACERTFILE, + SERVER_CERTFILE, + SERVER_KEYFILE, + SKIP_TESTS, + generate_ec_certificate, + run, +) + +real_sendto = socket.socket.sendto + + +def sendto_with_loss(self, data, addr=None): + """ + Simulate 25% packet loss. + """ + if random.random() > 0.25: + real_sendto(self, data, addr) + + +class SessionTicketStore: + def __init__(self): + self.tickets = {} + + def add(self, ticket): + self.tickets[ticket.ticket] = ticket + + def pop(self, label): + return self.tickets.pop(label, None) + + +def handle_stream(reader, writer): + async def serve(): + data = await reader.read() + writer.write(bytes(reversed(data))) + writer.write_eof() + + asyncio.ensure_future(serve()) + + +class HighLevelTest(TestCase): + def setUp(self): + self.server = None + self.server_host = "localhost" + self.server_port = 4433 + + def tearDown(self): + if self.server is not None: + self.server.close() + + async def run_client( + self, + host=None, + port=None, + cadata=None, + cafile=SERVER_CACERTFILE, + configuration=None, + request=b"ping", + **kwargs + ): + if host is None: + host = self.server_host + if port is None: + port = self.server_port + if configuration is None: + configuration = QuicConfiguration(is_client=True) + configuration.load_verify_locations(cadata=cadata, cafile=cafile) + async with connect(host, port, configuration=configuration, **kwargs) as client: + # waiting for connected when connected returns immediately + await client.wait_connected() + + reader, writer = await client.create_stream() + 
self.assertEqual(writer.can_write_eof(), True) + self.assertEqual(writer.get_extra_info("stream_id"), 0) + + writer.write(request) + writer.write_eof() + + response = await reader.read() + + # waiting for closed when closed returns immediately + await client.wait_closed() + + return response + + async def run_server(self, configuration=None, host="::", **kwargs): + if configuration is None: + configuration = QuicConfiguration(is_client=False) + configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) + self.server = await serve( + host=host, + port=self.server_port, + configuration=configuration, + stream_handler=handle_stream, + **kwargs + ) + return self.server + + def test_connect_and_serve(self): + run(self.run_server()) + response = run(self.run_client()) + self.assertEqual(response, b"gnip") + + def test_connect_and_serve_ipv4(self): + run(self.run_server(host="0.0.0.0")) + response = run(self.run_client(host="127.0.0.1")) + self.assertEqual(response, b"gnip") + + @skipIf("ipv6" in SKIP_TESTS, "Skipping IPv6 tests") + def test_connect_and_serve_ipv6(self): + run(self.run_server(host="::")) + response = run(self.run_client(host="::1")) + self.assertEqual(response, b"gnip") + + def test_connect_and_serve_ec_certificate(self): + certificate, private_key = generate_ec_certificate(common_name="localhost") + + run( + self.run_server( + configuration=QuicConfiguration( + certificate=certificate, private_key=private_key, is_client=False, + ) + ) + ) + + response = run( + self.run_client( + cadata=certificate.public_bytes(serialization.Encoding.PEM), + cafile=None, + ) + ) + + self.assertEqual(response, b"gnip") + + def test_connect_and_serve_large(self): + """ + Transfer enough data to require raising MAX_DATA and MAX_STREAM_DATA. 
+ """ + data = b"Z" * 2097152 + run(self.run_server()) + response = run(self.run_client(request=data)) + self.assertEqual(response, data) + + def test_connect_and_serve_without_client_configuration(self): + async def run_client_without_config(): + async with connect(self.server_host, self.server_port) as client: + await client.ping() + + run(self.run_server()) + with self.assertRaises(ConnectionError): + run(run_client_without_config()) + + def test_connect_and_serve_writelines(self): + async def run_client_writelines(): + configuration = QuicConfiguration(is_client=True) + configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + async with connect( + self.server_host, self.server_port, configuration=configuration + ) as client: + reader, writer = await client.create_stream() + assert writer.can_write_eof() is True + + writer.writelines([b"01234567", b"89012345"]) + writer.write_eof() + + return await reader.read() + + run(self.run_server()) + response = run(run_client_writelines()) + self.assertEqual(response, b"5432109876543210") + + @skipIf("loss" in SKIP_TESTS, "Skipping loss tests") + @patch("socket.socket.sendto", new_callable=lambda: sendto_with_loss) + def test_connect_and_serve_with_packet_loss(self, mock_sendto): + """ + This test ensures handshake success and stream data is successfully sent + and received in the presence of packet loss (randomized 25% in each direction). 
+ """ + data = b"Z" * 65536 + + server_configuration = QuicConfiguration( + is_client=False, quic_logger=QuicLogger() + ) + server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) + run(self.run_server(configuration=server_configuration)) + + response = run( + self.run_client( + configuration=QuicConfiguration( + is_client=True, quic_logger=QuicLogger() + ), + request=data, + ) + ) + self.assertEqual(response, data) + + def test_connect_and_serve_with_session_ticket(self): + # start server + client_ticket = None + store = SessionTicketStore() + + def save_ticket(t): + nonlocal client_ticket + client_ticket = t + + run( + self.run_server( + session_ticket_fetcher=store.pop, session_ticket_handler=store.add + ) + ) + + # first request + response = run(self.run_client(session_ticket_handler=save_ticket),) + self.assertEqual(response, b"gnip") + + self.assertIsNotNone(client_ticket) + + # second request + run( + self.run_client( + configuration=QuicConfiguration( + is_client=True, session_ticket=client_ticket + ), + ) + ) + self.assertEqual(response, b"gnip") + + def test_connect_and_serve_with_stateless_retry(self): + run(self.run_server()) + response = run(self.run_client()) + self.assertEqual(response, b"gnip") + + def test_connect_and_serve_with_stateless_retry_bad_original_connection_id(self): + """ + If the server's transport parameters do not have the correct + original_connection_id the connection fail. 
+ """ + + def create_protocol(*args, **kwargs): + protocol = QuicConnectionProtocol(*args, **kwargs) + protocol._quic._original_connection_id = None + return protocol + + run(self.run_server(create_protocol=create_protocol, stateless_retry=True)) + with self.assertRaises(ConnectionError): + run(self.run_client()) + + @patch("aioquic.quic.retry.QuicRetryTokenHandler.validate_token") + def test_connect_and_serve_with_stateless_retry_bad(self, mock_validate): + mock_validate.side_effect = ValueError("Decryption failed.") + + run(self.run_server(stateless_retry=True)) + with self.assertRaises(ConnectionError): + run( + self.run_client( + configuration=QuicConfiguration(is_client=True, idle_timeout=4.0), + ) + ) + + def test_connect_and_serve_with_version_negotiation(self): + run(self.run_server()) + + # force version negotiation + configuration = QuicConfiguration(is_client=True, quic_logger=QuicLogger()) + configuration.supported_versions.insert(0, 0x1A2A3A4A) + + response = run(self.run_client(configuration=configuration)) + self.assertEqual(response, b"gnip") + + def test_connect_timeout(self): + with self.assertRaises(ConnectionError): + run( + self.run_client( + port=4400, + configuration=QuicConfiguration(is_client=True, idle_timeout=5), + ) + ) + + def test_connect_timeout_no_wait_connected(self): + async def run_client_no_wait_connected(configuration): + configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + async with connect( + self.server_host, + 4400, + configuration=configuration, + wait_connected=False, + ) as client: + await client.ping() + + with self.assertRaises(ConnectionError): + run( + run_client_no_wait_connected( + configuration=QuicConfiguration(is_client=True, idle_timeout=5), + ) + ) + + def test_connect_local_port(self): + run(self.run_server()) + response = run(self.run_client(local_port=3456)) + self.assertEqual(response, b"gnip") + + def test_change_connection_id(self): + async def run_client_change_connection_id(): + 
configuration = QuicConfiguration(is_client=True) + configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + async with connect( + self.server_host, self.server_port, configuration=configuration + ) as client: + await client.ping() + client.change_connection_id() + await client.ping() + + run(self.run_server()) + run(run_client_change_connection_id()) + + def test_key_update(self): + async def run_client_key_update(): + configuration = QuicConfiguration(is_client=True) + configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + async with connect( + self.server_host, self.server_port, configuration=configuration + ) as client: + await client.ping() + client.request_key_update() + await client.ping() + + run(self.run_server()) + run(run_client_key_update()) + + def test_ping(self): + async def run_client_ping(): + configuration = QuicConfiguration(is_client=True) + configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + async with connect( + self.server_host, self.server_port, configuration=configuration + ) as client: + await client.ping() + await client.ping() + + run(self.run_server()) + run(run_client_ping()) + + def test_ping_parallel(self): + async def run_client_ping(): + configuration = QuicConfiguration(is_client=True) + configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + async with connect( + self.server_host, self.server_port, configuration=configuration + ) as client: + coros = [client.ping() for x in range(16)] + await asyncio.gather(*coros) + + run(self.run_server()) + run(run_client_ping()) + + def test_server_receives_garbage(self): + server = run(self.run_server()) + server.datagram_received(binascii.unhexlify("c00000000080"), ("1.2.3.4", 1234)) + server.close() diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_asyncio_compat.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_asyncio_compat.py new file mode 100644 index 000000000000..906dbf51124c --- /dev/null +++ 
b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_asyncio_compat.py @@ -0,0 +1,38 @@ +import asyncio +from unittest import TestCase + +from aioquic.asyncio.compat import _asynccontextmanager + +from .utils import run + + +@_asynccontextmanager +async def some_context(): + await asyncio.sleep(0) + yield + await asyncio.sleep(0) + + +class AsyncioCompatTest(TestCase): + def test_ok(self): + async def test(): + async with some_context(): + pass + + run(test()) + + def test_raise_exception(self): + async def test(): + async with some_context(): + raise RuntimeError("some reason") + + with self.assertRaises(RuntimeError): + run(test()) + + def test_raise_exception_type(self): + async def test(): + async with some_context(): + raise RuntimeError + + with self.assertRaises(RuntimeError): + run(test()) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_buffer.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_buffer.py new file mode 100644 index 000000000000..51dbc44d0dd2 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_buffer.py @@ -0,0 +1,201 @@ +from unittest import TestCase + +from aioquic.buffer import Buffer, BufferReadError, BufferWriteError, size_uint_var + + +class BufferTest(TestCase): + def test_data_slice(self): + buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") + self.assertEqual(buf.data_slice(0, 8), b"\x08\x07\x06\x05\x04\x03\x02\x01") + self.assertEqual(buf.data_slice(1, 3), b"\x07\x06") + + with self.assertRaises(BufferReadError): + buf.data_slice(-1, 3) + with self.assertRaises(BufferReadError): + buf.data_slice(0, 9) + with self.assertRaises(BufferReadError): + buf.data_slice(1, 0) + + def test_pull_bytes(self): + buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") + self.assertEqual(buf.pull_bytes(3), b"\x08\x07\x06") + + def test_pull_bytes_negative(self): + buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") + with 
self.assertRaises(BufferReadError): + buf.pull_bytes(-1) + + def test_pull_bytes_truncated(self): + buf = Buffer(capacity=0) + with self.assertRaises(BufferReadError): + buf.pull_bytes(2) + self.assertEqual(buf.tell(), 0) + + def test_pull_bytes_zero(self): + buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") + self.assertEqual(buf.pull_bytes(0), b"") + + def test_pull_uint8(self): + buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") + self.assertEqual(buf.pull_uint8(), 0x08) + self.assertEqual(buf.tell(), 1) + + def test_pull_uint8_truncated(self): + buf = Buffer(capacity=0) + with self.assertRaises(BufferReadError): + buf.pull_uint8() + self.assertEqual(buf.tell(), 0) + + def test_pull_uint16(self): + buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") + self.assertEqual(buf.pull_uint16(), 0x0807) + self.assertEqual(buf.tell(), 2) + + def test_pull_uint16_truncated(self): + buf = Buffer(capacity=1) + with self.assertRaises(BufferReadError): + buf.pull_uint16() + self.assertEqual(buf.tell(), 0) + + def test_pull_uint32(self): + buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") + self.assertEqual(buf.pull_uint32(), 0x08070605) + self.assertEqual(buf.tell(), 4) + + def test_pull_uint32_truncated(self): + buf = Buffer(capacity=3) + with self.assertRaises(BufferReadError): + buf.pull_uint32() + self.assertEqual(buf.tell(), 0) + + def test_pull_uint64(self): + buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01") + self.assertEqual(buf.pull_uint64(), 0x0807060504030201) + self.assertEqual(buf.tell(), 8) + + def test_pull_uint64_truncated(self): + buf = Buffer(capacity=7) + with self.assertRaises(BufferReadError): + buf.pull_uint64() + self.assertEqual(buf.tell(), 0) + + def test_push_bytes(self): + buf = Buffer(capacity=3) + buf.push_bytes(b"\x08\x07\x06") + self.assertEqual(buf.data, b"\x08\x07\x06") + self.assertEqual(buf.tell(), 3) + + def test_push_bytes_truncated(self): + buf = Buffer(capacity=3) + with self.assertRaises(BufferWriteError): + 
buf.push_bytes(b"\x08\x07\x06\x05") + self.assertEqual(buf.tell(), 0) + + def test_push_bytes_zero(self): + buf = Buffer(capacity=3) + buf.push_bytes(b"") + self.assertEqual(buf.data, b"") + self.assertEqual(buf.tell(), 0) + + def test_push_uint8(self): + buf = Buffer(capacity=1) + buf.push_uint8(0x08) + self.assertEqual(buf.data, b"\x08") + self.assertEqual(buf.tell(), 1) + + def test_push_uint16(self): + buf = Buffer(capacity=2) + buf.push_uint16(0x0807) + self.assertEqual(buf.data, b"\x08\x07") + self.assertEqual(buf.tell(), 2) + + def test_push_uint32(self): + buf = Buffer(capacity=4) + buf.push_uint32(0x08070605) + self.assertEqual(buf.data, b"\x08\x07\x06\x05") + self.assertEqual(buf.tell(), 4) + + def test_push_uint64(self): + buf = Buffer(capacity=8) + buf.push_uint64(0x0807060504030201) + self.assertEqual(buf.data, b"\x08\x07\x06\x05\x04\x03\x02\x01") + self.assertEqual(buf.tell(), 8) + + def test_seek(self): + buf = Buffer(data=b"01234567") + self.assertFalse(buf.eof()) + self.assertEqual(buf.tell(), 0) + + buf.seek(4) + self.assertFalse(buf.eof()) + self.assertEqual(buf.tell(), 4) + + buf.seek(8) + self.assertTrue(buf.eof()) + self.assertEqual(buf.tell(), 8) + + with self.assertRaises(BufferReadError): + buf.seek(-1) + self.assertEqual(buf.tell(), 8) + with self.assertRaises(BufferReadError): + buf.seek(9) + self.assertEqual(buf.tell(), 8) + + +class UintVarTest(TestCase): + def roundtrip(self, data, value): + buf = Buffer(data=data) + self.assertEqual(buf.pull_uint_var(), value) + self.assertEqual(buf.tell(), len(data)) + + buf = Buffer(capacity=8) + buf.push_uint_var(value) + self.assertEqual(buf.data, data) + + def test_uint_var(self): + # 1 byte + self.roundtrip(b"\x00", 0) + self.roundtrip(b"\x01", 1) + self.roundtrip(b"\x25", 37) + self.roundtrip(b"\x3f", 63) + + # 2 bytes + self.roundtrip(b"\x7b\xbd", 15293) + self.roundtrip(b"\x7f\xff", 16383) + + # 4 bytes + self.roundtrip(b"\x9d\x7f\x3e\x7d", 494878333) + self.roundtrip(b"\xbf\xff\xff\xff", 
1073741823) + + # 8 bytes + self.roundtrip(b"\xc2\x19\x7c\x5e\xff\x14\xe8\x8c", 151288809941952652) + self.roundtrip(b"\xff\xff\xff\xff\xff\xff\xff\xff", 4611686018427387903) + + def test_pull_uint_var_truncated(self): + buf = Buffer(capacity=0) + with self.assertRaises(BufferReadError): + buf.pull_uint_var() + + buf = Buffer(data=b"\xff") + with self.assertRaises(BufferReadError): + buf.pull_uint_var() + + def test_push_uint_var_too_big(self): + buf = Buffer(capacity=8) + with self.assertRaises(ValueError) as cm: + buf.push_uint_var(4611686018427387904) + self.assertEqual( + str(cm.exception), "Integer is too big for a variable-length integer" + ) + + def test_size_uint_var(self): + self.assertEqual(size_uint_var(63), 1) + self.assertEqual(size_uint_var(16383), 2) + self.assertEqual(size_uint_var(1073741823), 4) + self.assertEqual(size_uint_var(4611686018427387903), 8) + + with self.assertRaises(ValueError) as cm: + size_uint_var(4611686018427387904) + self.assertEqual( + str(cm.exception), "Integer is too big for a variable-length integer" + ) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_connection.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_connection.py new file mode 100644 index 000000000000..af1a5cdc58f5 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_connection.py @@ -0,0 +1,1800 @@ +import asyncio +import binascii +import contextlib +import io +import time +from unittest import TestCase + +from aioquic import tls +from aioquic.buffer import UINT_VAR_MAX, Buffer, encode_uint_var +from aioquic.quic import events +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.connection import ( + QuicConnection, + QuicConnectionError, + QuicNetworkPath, + QuicReceiveContext, +) +from aioquic.quic.crypto import CryptoPair +from aioquic.quic.logger import QuicLogger +from aioquic.quic.packet import ( + PACKET_TYPE_INITIAL, + QuicErrorCode, + QuicFrameType, 
+ encode_quic_retry, + encode_quic_version_negotiation, +) +from aioquic.quic.packet_builder import QuicDeliveryState, QuicPacketBuilder +from aioquic.quic.recovery import QuicPacketPacer + +from .utils import ( + SERVER_CACERTFILE, + SERVER_CERTFILE, + SERVER_CERTFILE_WITH_CHAIN, + SERVER_KEYFILE, +) + +CLIENT_ADDR = ("1.2.3.4", 1234) + +SERVER_ADDR = ("2.3.4.5", 4433) + + +class SessionTicketStore: + def __init__(self): + self.tickets = {} + + def add(self, ticket): + self.tickets[ticket.ticket] = ticket + + def pop(self, label): + return self.tickets.pop(label, None) + + +def client_receive_context(client, epoch=tls.Epoch.ONE_RTT): + return QuicReceiveContext( + epoch=epoch, + host_cid=client.host_cid, + network_path=client._network_paths[0], + quic_logger_frames=[], + time=asyncio.get_event_loop().time(), + ) + + +def consume_events(connection): + while True: + event = connection.next_event() + if event is None: + break + + +def create_standalone_client(self, **client_options): + client = QuicConnection( + configuration=QuicConfiguration( + is_client=True, quic_logger=QuicLogger(), **client_options + ) + ) + client._ack_delay = 0 + + # kick-off handshake + client.connect(SERVER_ADDR, now=time.time()) + self.assertEqual(drop(client), 1) + + return client + + +@contextlib.contextmanager +def client_and_server( + client_kwargs={}, + client_options={}, + client_patch=lambda x: None, + handshake=True, + server_kwargs={}, + server_certfile=SERVER_CERTFILE, + server_keyfile=SERVER_KEYFILE, + server_options={}, + server_patch=lambda x: None, +): + client_configuration = QuicConfiguration( + is_client=True, quic_logger=QuicLogger(), **client_options + ) + client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + + client = QuicConnection(configuration=client_configuration, **client_kwargs) + client._ack_delay = 0 + disable_packet_pacing(client) + client_patch(client) + + server_configuration = QuicConfiguration( + is_client=False, quic_logger=QuicLogger(), 
**server_options + ) + server_configuration.load_cert_chain(server_certfile, server_keyfile) + + server = QuicConnection(configuration=server_configuration, **server_kwargs) + server._ack_delay = 0 + disable_packet_pacing(server) + server_patch(server) + + # perform handshake + if handshake: + client.connect(SERVER_ADDR, now=time.time()) + for i in range(3): + roundtrip(client, server) + + yield client, server + + # close + client.close() + server.close() + + +def disable_packet_pacing(connection): + class DummyPacketPacer(QuicPacketPacer): + def next_send_time(self, now): + return None + + connection._loss._pacer = DummyPacketPacer() + + +def sequence_numbers(connection_ids): + return list(map(lambda x: x.sequence_number, connection_ids)) + + +def drop(sender): + """ + Drop datagrams from `sender`. + """ + return len(sender.datagrams_to_send(now=time.time())) + + +def roundtrip(sender, receiver): + """ + Send datagrams from `sender` to `receiver` and back. + """ + return (transfer(sender, receiver), transfer(receiver, sender)) + + +def transfer(sender, receiver): + """ + Send datagrams from `sender` to `receiver`. + """ + datagrams = 0 + from_addr = CLIENT_ADDR if sender._is_client else SERVER_ADDR + for data, addr in sender.datagrams_to_send(now=time.time()): + datagrams += 1 + receiver.receive_datagram(data, from_addr, now=time.time()) + return datagrams + + +class QuicConnectionTest(TestCase): + def check_handshake(self, client, server, alpn_protocol=None): + """ + Check handshake completed. 
+ """ + event = client.next_event() + self.assertEqual(type(event), events.ProtocolNegotiated) + self.assertEqual(event.alpn_protocol, alpn_protocol) + event = client.next_event() + self.assertEqual(type(event), events.HandshakeCompleted) + self.assertEqual(event.alpn_protocol, alpn_protocol) + self.assertEqual(event.early_data_accepted, False) + self.assertEqual(event.session_resumed, False) + for i in range(7): + self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) + self.assertIsNone(client.next_event()) + + event = server.next_event() + self.assertEqual(type(event), events.ProtocolNegotiated) + self.assertEqual(event.alpn_protocol, alpn_protocol) + event = server.next_event() + self.assertEqual(type(event), events.HandshakeCompleted) + self.assertEqual(event.alpn_protocol, alpn_protocol) + for i in range(7): + self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) + self.assertIsNone(server.next_event()) + + def test_connect(self): + with client_and_server() as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + + # check each endpoint has available connection IDs for the peer + self.assertEqual( + sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] + ) + self.assertEqual( + sequence_numbers(server._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] + ) + + # client closes the connection + client.close() + self.assertEqual(transfer(client, server), 1) + + # check connection closes on the client side + client.handle_timer(client.get_timer()) + event = client.next_event() + self.assertEqual(type(event), events.ConnectionTerminated) + self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR) + self.assertEqual(event.frame_type, None) + self.assertEqual(event.reason_phrase, "") + self.assertIsNone(client.next_event()) + + # check connection closes on the server side + server.handle_timer(server.get_timer()) + event = server.next_event() + self.assertEqual(type(event), 
events.ConnectionTerminated) + self.assertEqual(event.error_code, QuicErrorCode.NO_ERROR) + self.assertEqual(event.frame_type, None) + self.assertEqual(event.reason_phrase, "") + self.assertIsNone(server.next_event()) + + # check client log + client_log = client.configuration.quic_logger.to_dict() + self.assertGreater(len(client_log["traces"][0]["events"]), 20) + + # check server log + server_log = server.configuration.quic_logger.to_dict() + self.assertGreater(len(server_log["traces"][0]["events"]), 20) + + def test_connect_with_alpn(self): + with client_and_server( + client_options={"alpn_protocols": ["h3-25", "hq-25"]}, + server_options={"alpn_protocols": ["hq-25"]}, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server, alpn_protocol="hq-25") + + def test_connect_with_secrets_log(self): + client_log_file = io.StringIO() + server_log_file = io.StringIO() + with client_and_server( + client_options={"secrets_log_file": client_log_file}, + server_options={"secrets_log_file": server_log_file}, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server) + + # check secrets were logged + client_log = client_log_file.getvalue() + server_log = server_log_file.getvalue() + self.assertEqual(client_log, server_log) + labels = [] + for line in client_log.splitlines(): + labels.append(line.split()[0]) + self.assertEqual( + labels, + [ + "QUIC_SERVER_HANDSHAKE_TRAFFIC_SECRET", + "QUIC_CLIENT_HANDSHAKE_TRAFFIC_SECRET", + "QUIC_SERVER_TRAFFIC_SECRET_0", + "QUIC_CLIENT_TRAFFIC_SECRET_0", + ], + ) + + def test_connect_with_cert_chain(self): + with client_and_server(server_certfile=SERVER_CERTFILE_WITH_CHAIN) as ( + client, + server, + ): + # check handshake completed + self.check_handshake(client=client, server=server) + + def test_connect_with_loss_1(self): + """ + Check connection is established even in the client's INITIAL is lost. 
+ """ + + def datagram_sizes(items): + return [len(x[0]) for x in items] + + client_configuration = QuicConfiguration(is_client=True) + client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + + client = QuicConnection(configuration=client_configuration) + client._ack_delay = 0 + + server_configuration = QuicConfiguration(is_client=False) + server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) + + server = QuicConnection(configuration=server_configuration) + server._ack_delay = 0 + + # client sends INITIAL + now = 0.0 + client.connect(SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280]) + self.assertEqual(client.get_timer(), 1.0) + + # INITIAL is lost + now = 1.0 + client.handle_timer(now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280]) + self.assertEqual(client.get_timer(), 3.0) + + # server receives INITIAL, sends INITIAL + HANDSHAKE + now = 1.1 + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1062]) + self.assertEqual(server.get_timer(), 2.1) + self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) + self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) + self.assertEqual(type(server.next_event()), events.ProtocolNegotiated) + self.assertIsNone(server.next_event()) + + # handshake continues normally + now = 1.2 + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [376]) + self.assertAlmostEqual(client.get_timer(), 1.825) + self.assertEqual(type(client.next_event()), events.ProtocolNegotiated) + self.assertEqual(type(client.next_event()), events.HandshakeCompleted) + self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) + + now = 1.3 + 
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [229]) + self.assertAlmostEqual(server.get_timer(), 1.825) + self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) + self.assertEqual(len(server._loss.spaces[1].sent_packets), 0) + self.assertEqual(type(server.next_event()), events.HandshakeCompleted) + self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) + + now = 1.4 + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 61.4) # idle timeout + + def test_connect_with_loss_2(self): + def datagram_sizes(items): + return [len(x[0]) for x in items] + + client_configuration = QuicConfiguration(is_client=True) + client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + + client = QuicConnection(configuration=client_configuration) + client._ack_delay = 0 + + server_configuration = QuicConfiguration(is_client=False) + server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) + + server = QuicConnection(configuration=server_configuration) + server._ack_delay = 0 + + # client sends INITIAL + now = 0.0 + client.connect(SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280]) + self.assertEqual(client.get_timer(), 1.0) + + # server receives INITIAL, sends INITIAL + HANDSHAKE but second datagram is lost + now = 0.1 + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1062]) + self.assertEqual(server.get_timer(), 1.1) + self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) + self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) + + # client only receives first datagram and sends ACKS + now = 0.2 + 
client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [97]) + self.assertAlmostEqual(client.get_timer(), 0.625) + self.assertEqual(type(client.next_event()), events.ProtocolNegotiated) + self.assertIsNone(client.next_event()) + + # client PTO - HANDSHAKE PING + now = client.get_timer() # ~0.625 + client.handle_timer(now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [44]) + self.assertAlmostEqual(client.get_timer(), 1.875) + + # server receives PING, discards INITIAL and sends ACK + now = 0.725 + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [48]) + self.assertAlmostEqual(server.get_timer(), 1.1) + self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) + self.assertEqual(len(server._loss.spaces[1].sent_packets), 3) + self.assertEqual(type(server.next_event()), events.ProtocolNegotiated) + self.assertIsNone(server.next_event()) + + # ACKs are lost, server retransmits HANDSHAKE + now = server.get_timer() + server.handle_timer(now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 854]) + self.assertAlmostEqual(server.get_timer(), 3.1) + self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) + self.assertEqual(len(server._loss.spaces[1].sent_packets), 3) + self.assertIsNone(server.next_event()) + + # handshake continues normally + now = 1.2 + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [329]) + self.assertAlmostEqual(client.get_timer(), 2.45) + self.assertEqual(type(client.next_event()), events.HandshakeCompleted) + self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) + + now = 1.3 + 
server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [229]) + self.assertAlmostEqual(server.get_timer(), 1.925) + self.assertEqual(type(server.next_event()), events.HandshakeCompleted) + self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) + + now = 1.4 + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 61.4) # idle timeout + + def test_connect_with_loss_3(self): + def datagram_sizes(items): + return [len(x[0]) for x in items] + + client_configuration = QuicConfiguration(is_client=True) + client_configuration.load_verify_locations(cafile=SERVER_CACERTFILE) + + client = QuicConnection(configuration=client_configuration) + client._ack_delay = 0 + + server_configuration = QuicConfiguration(is_client=False) + server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) + + server = QuicConnection(configuration=server_configuration) + server._ack_delay = 0 + + # client sends INITIAL + now = 0.0 + client.connect(SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280]) + self.assertEqual(client.get_timer(), 1.0) + + # server receives INITIAL, sends INITIAL + HANDSHAKE + now = 0.1 + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [1280, 1062]) + self.assertEqual(server.get_timer(), 1.1) + self.assertEqual(len(server._loss.spaces[0].sent_packets), 1) + self.assertEqual(len(server._loss.spaces[1].sent_packets), 2) + + # client receives INITIAL + HANDSHAKE + now = 0.2 + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + client.receive_datagram(items[1][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + 
self.assertEqual(datagram_sizes(items), [376]) + self.assertAlmostEqual(client.get_timer(), 0.825) + self.assertEqual(type(client.next_event()), events.ProtocolNegotiated) + self.assertEqual(type(client.next_event()), events.HandshakeCompleted) + self.assertEqual(type(client.next_event()), events.ConnectionIdIssued) + + # server completes handshake + now = 0.3 + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [229]) + self.assertAlmostEqual(server.get_timer(), 0.825) + self.assertEqual(len(server._loss.spaces[0].sent_packets), 0) + self.assertEqual(len(server._loss.spaces[1].sent_packets), 0) + self.assertEqual(type(server.next_event()), events.ProtocolNegotiated) + self.assertEqual(type(server.next_event()), events.HandshakeCompleted) + self.assertEqual(type(server.next_event()), events.ConnectionIdIssued) + + # server PTO - 1-RTT PING + now = 0.825 + server.handle_timer(now=now) + items = server.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [29]) + self.assertAlmostEqual(server.get_timer(), 1.875) + + # client receives PING, sends ACK + now = 0.9 + client.receive_datagram(items[0][0], SERVER_ADDR, now=now) + items = client.datagrams_to_send(now=now) + self.assertEqual(datagram_sizes(items), [32]) + self.assertAlmostEqual(client.get_timer(), 0.825) + + # server receives ACK, retransmits HANDSHAKE_DONE + now = 1.0 + self.assertFalse(server._handshake_done_pending) + server.receive_datagram(items[0][0], CLIENT_ADDR, now=now) + self.assertTrue(server._handshake_done_pending) + items = server.datagrams_to_send(now=now) + self.assertFalse(server._handshake_done_pending) + self.assertEqual(datagram_sizes(items), [224]) + + def test_connect_with_quantum_readiness(self): + with client_and_server(client_options={"quantum_readiness_test": True},) as ( + client, + server, + ): + stream_id = client.get_next_available_stream_id() + 
client.send_stream_data(stream_id, b"hello") + + self.assertEqual(roundtrip(client, server), (1, 1)) + + received = None + while True: + event = server.next_event() + if isinstance(event, events.StreamDataReceived): + received = event.data + elif event is None: + break + + self.assertEqual(received, b"hello") + + def test_connect_with_0rtt(self): + client_ticket = None + ticket_store = SessionTicketStore() + + def save_session_ticket(ticket): + nonlocal client_ticket + client_ticket = ticket + + with client_and_server( + client_kwargs={"session_ticket_handler": save_session_ticket}, + server_kwargs={"session_ticket_handler": ticket_store.add}, + ) as (client, server): + pass + + with client_and_server( + client_options={"session_ticket": client_ticket}, + server_kwargs={"session_ticket_fetcher": ticket_store.pop}, + handshake=False, + ) as (client, server): + client.connect(SERVER_ADDR, now=time.time()) + stream_id = client.get_next_available_stream_id() + client.send_stream_data(stream_id, b"hello") + + self.assertEqual(roundtrip(client, server), (2, 1)) + + event = server.next_event() + self.assertEqual(type(event), events.ProtocolNegotiated) + + event = server.next_event() + self.assertEqual(type(event), events.StreamDataReceived) + self.assertEqual(event.data, b"hello") + + def test_connect_with_0rtt_bad_max_early_data(self): + client_ticket = None + ticket_store = SessionTicketStore() + + def patch(server): + """ + Patch server's TLS initialization to set an invalid + max_early_data value. 
+ """ + real_initialize = server._initialize + + def patched_initialize(peer_cid: bytes): + real_initialize(peer_cid) + server.tls._max_early_data = 12345 + + server._initialize = patched_initialize + + def save_session_ticket(ticket): + nonlocal client_ticket + client_ticket = ticket + + with client_and_server( + client_kwargs={"session_ticket_handler": save_session_ticket}, + server_kwargs={"session_ticket_handler": ticket_store.add}, + server_patch=patch, + ) as (client, server): + # check handshake failed + event = client.next_event() + self.assertIsNone(event) + + def test_change_connection_id(self): + with client_and_server() as (client, server): + self.assertEqual( + sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] + ) + + # the client changes connection ID + client.change_connection_id() + self.assertEqual(transfer(client, server), 1) + self.assertEqual( + sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7] + ) + + # the server provides a new connection ID + self.assertEqual(transfer(server, client), 1) + self.assertEqual( + sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7, 8] + ) + + def test_change_connection_id_retransmit_new_connection_id(self): + with client_and_server() as (client, server): + self.assertEqual( + sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] + ) + + # the client changes connection ID + client.change_connection_id() + self.assertEqual(transfer(client, server), 1) + self.assertEqual( + sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7] + ) + + # the server provides a new connection ID, NEW_CONNECTION_ID is lost + self.assertEqual(drop(server), 1) + self.assertEqual( + sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7] + ) + + # NEW_CONNECTION_ID is retransmitted + server._on_new_connection_id_delivery( + QuicDeliveryState.LOST, server._host_cids[-1] + ) + self.assertEqual(transfer(server, client), 1) + self.assertEqual( + 
sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7, 8] + ) + + def test_change_connection_id_retransmit_retire_connection_id(self): + with client_and_server() as (client, server): + self.assertEqual( + sequence_numbers(client._peer_cid_available), [1, 2, 3, 4, 5, 6, 7] + ) + + # the client changes connection ID, RETIRE_CONNECTION_ID is lost + client.change_connection_id() + self.assertEqual(drop(client), 1) + self.assertEqual( + sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7] + ) + + # RETIRE_CONNECTION_ID is retransmitted + client._on_retire_connection_id_delivery(QuicDeliveryState.LOST, 0) + self.assertEqual(transfer(client, server), 1) + + # the server provides a new connection ID + self.assertEqual(transfer(server, client), 1) + self.assertEqual( + sequence_numbers(client._peer_cid_available), [2, 3, 4, 5, 6, 7, 8] + ) + + def test_get_next_available_stream_id(self): + with client_and_server() as (client, server): + # client + stream_id = client.get_next_available_stream_id() + self.assertEqual(stream_id, 0) + client.send_stream_data(stream_id, b"hello") + + stream_id = client.get_next_available_stream_id() + self.assertEqual(stream_id, 4) + client.send_stream_data(stream_id, b"hello") + + stream_id = client.get_next_available_stream_id(is_unidirectional=True) + self.assertEqual(stream_id, 2) + client.send_stream_data(stream_id, b"hello") + + stream_id = client.get_next_available_stream_id(is_unidirectional=True) + self.assertEqual(stream_id, 6) + client.send_stream_data(stream_id, b"hello") + + # server + stream_id = server.get_next_available_stream_id() + self.assertEqual(stream_id, 1) + server.send_stream_data(stream_id, b"hello") + + stream_id = server.get_next_available_stream_id() + self.assertEqual(stream_id, 5) + server.send_stream_data(stream_id, b"hello") + + stream_id = server.get_next_available_stream_id(is_unidirectional=True) + self.assertEqual(stream_id, 3) + server.send_stream_data(stream_id, b"hello") + + stream_id 
= server.get_next_available_stream_id(is_unidirectional=True) + self.assertEqual(stream_id, 7) + server.send_stream_data(stream_id, b"hello") + + def test_datagram_frame(self): + with client_and_server( + client_options={"max_datagram_frame_size": 65536}, + server_options={"max_datagram_frame_size": 65536}, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server, alpn_protocol=None) + + # send datagram + client.send_datagram_frame(b"hello") + self.assertEqual(transfer(client, server), 1) + + event = server.next_event() + self.assertEqual(type(event), events.DatagramFrameReceived) + self.assertEqual(event.data, b"hello") + + def test_datagram_frame_2(self): + # payload which exactly fills an entire packet + payload = b"Z" * 1250 + + with client_and_server( + client_options={"max_datagram_frame_size": 65536}, + server_options={"max_datagram_frame_size": 65536}, + ) as (client, server): + # check handshake completed + self.check_handshake(client=client, server=server, alpn_protocol=None) + + # queue 20 datagrams + for i in range(20): + client.send_datagram_frame(payload) + + # client can only 11 datagrams are sent due to congestion control + self.assertEqual(transfer(client, server), 11) + for i in range(11): + event = server.next_event() + self.assertEqual(type(event), events.DatagramFrameReceived) + self.assertEqual(event.data, payload) + + # server sends ACK + self.assertEqual(transfer(server, client), 1) + + # client sends remaining datagrams + self.assertEqual(transfer(client, server), 9) + for i in range(9): + event = server.next_event() + self.assertEqual(type(event), events.DatagramFrameReceived) + self.assertEqual(event.data, payload) + + def test_decryption_error(self): + with client_and_server() as (client, server): + # mess with encryption key + server._cryptos[tls.Epoch.ONE_RTT].send.setup( + cipher_suite=tls.CipherSuite.AES_128_GCM_SHA256, + secret=bytes(48), + version=server._version, + ) + + # server 
sends close + server.close(error_code=QuicErrorCode.NO_ERROR) + for data, addr in server.datagrams_to_send(now=time.time()): + client.receive_datagram(data, SERVER_ADDR, now=time.time()) + + def test_tls_error(self): + def patch(client): + real_initialize = client._initialize + + def patched_initialize(peer_cid: bytes): + real_initialize(peer_cid) + client.tls._supported_versions = [tls.TLS_VERSION_1_3_DRAFT_28] + + client._initialize = patched_initialize + + # handshake fails + with client_and_server(client_patch=patch) as (client, server): + timer_at = server.get_timer() + server.handle_timer(timer_at) + + event = server.next_event() + self.assertEqual(type(event), events.ConnectionTerminated) + self.assertEqual(event.error_code, 326) + self.assertEqual(event.frame_type, QuicFrameType.CRYPTO) + self.assertEqual(event.reason_phrase, "No supported protocol version") + + def test_receive_datagram_garbage(self): + client = create_standalone_client(self) + + datagram = binascii.unhexlify("c00000000080") + client.receive_datagram(datagram, SERVER_ADDR, now=time.time()) + + def test_receive_datagram_reserved_bits_non_zero(self): + client = create_standalone_client(self) + + builder = QuicPacketBuilder( + host_cid=client._peer_cid, + is_client=False, + peer_cid=client.host_cid, + version=client._version, + ) + crypto = CryptoPair() + crypto.setup_initial(client._peer_cid, is_client=False, version=client._version) + crypto.encrypt_packet_real = crypto.encrypt_packet + + def encrypt_packet(plain_header, plain_payload, packet_number): + # mess with reserved bits + plain_header = bytes([plain_header[0] | 0x0C]) + plain_header[1:] + return crypto.encrypt_packet_real( + plain_header, plain_payload, packet_number + ) + + crypto.encrypt_packet = encrypt_packet + + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + buf = builder.start_frame(QuicFrameType.PADDING) + buf.push_bytes(bytes(builder.remaining_flight_space)) + + for datagram in builder.flush()[0]: + 
client.receive_datagram(datagram, SERVER_ADDR, now=time.time()) + self.assertEqual(drop(client), 1) + self.assertEqual( + client._close_event, + events.ConnectionTerminated( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=None, + reason_phrase="Reserved bits must be zero", + ), + ) + + def test_receive_datagram_wrong_version(self): + client = create_standalone_client(self) + + builder = QuicPacketBuilder( + host_cid=client._peer_cid, + is_client=False, + peer_cid=client.host_cid, + version=0xFF000011, # DRAFT_16 + ) + crypto = CryptoPair() + crypto.setup_initial(client._peer_cid, is_client=False, version=client._version) + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + buf = builder.start_frame(QuicFrameType.PADDING) + buf.push_bytes(bytes(builder.remaining_flight_space)) + + for datagram in builder.flush()[0]: + client.receive_datagram(datagram, SERVER_ADDR, now=time.time()) + self.assertEqual(drop(client), 0) + + def test_receive_datagram_retry(self): + client = create_standalone_client(self) + + client.receive_datagram( + encode_quic_retry( + version=client._version, + source_cid=binascii.unhexlify("85abb547bf28be97"), + destination_cid=client.host_cid, + original_destination_cid=client._peer_cid, + retry_token=bytes(16), + ), + SERVER_ADDR, + now=time.time(), + ) + self.assertEqual(drop(client), 1) + + def test_receive_datagram_retry_wrong_destination_cid(self): + client = create_standalone_client(self) + + client.receive_datagram( + encode_quic_retry( + version=client._version, + source_cid=binascii.unhexlify("85abb547bf28be97"), + destination_cid=binascii.unhexlify("c98343fe8f5f0ff4"), + original_destination_cid=client._peer_cid, + retry_token=bytes(16), + ), + SERVER_ADDR, + now=time.time(), + ) + self.assertEqual(drop(client), 0) + + def test_handle_ack_frame_ecn(self): + client = create_standalone_client(self) + + client._handle_ack_frame( + client_receive_context(client), + QuicFrameType.ACK_ECN, + 
Buffer(data=b"\x00\x02\x00\x00\x00\x00\x00"), + ) + + def test_handle_connection_close_frame(self): + with client_and_server() as (client, server): + server.close( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=QuicFrameType.ACK, + reason_phrase="illegal ACK frame", + ) + self.assertEqual(roundtrip(server, client), (1, 0)) + + self.assertEqual( + client._close_event, + events.ConnectionTerminated( + error_code=QuicErrorCode.PROTOCOL_VIOLATION, + frame_type=QuicFrameType.ACK, + reason_phrase="illegal ACK frame", + ), + ) + + def test_handle_connection_close_frame_app(self): + with client_and_server() as (client, server): + server.close(error_code=QuicErrorCode.NO_ERROR, reason_phrase="goodbye") + self.assertEqual(roundtrip(server, client), (1, 0)) + + self.assertEqual( + client._close_event, + events.ConnectionTerminated( + error_code=QuicErrorCode.NO_ERROR, + frame_type=None, + reason_phrase="goodbye", + ), + ) + + def test_handle_connection_close_frame_app_not_utf8(self): + client = create_standalone_client(self) + + client._handle_connection_close_frame( + client_receive_context(client), + QuicFrameType.APPLICATION_CLOSE, + Buffer(data=binascii.unhexlify("0008676f6f6462798200")), + ) + + self.assertEqual( + client._close_event, + events.ConnectionTerminated( + error_code=QuicErrorCode.NO_ERROR, frame_type=None, reason_phrase="", + ), + ) + + def test_handle_crypto_frame_over_largest_offset(self): + with client_and_server() as (client, server): + # client receives offset + length > 2^62 - 1 + with self.assertRaises(QuicConnectionError) as cm: + client._handle_crypto_frame( + client_receive_context(client), + QuicFrameType.CRYPTO, + Buffer(data=encode_uint_var(UINT_VAR_MAX) + encode_uint_var(1)), + ) + self.assertEqual( + cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR + ) + self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) + self.assertEqual( + cm.exception.reason_phrase, "offset + length cannot exceed 2^62 - 1" + ) + + def 
test_handle_data_blocked_frame(self): + with client_and_server() as (client, server): + # client receives DATA_BLOCKED: 12345 + client._handle_data_blocked_frame( + client_receive_context(client), + QuicFrameType.DATA_BLOCKED, + Buffer(data=encode_uint_var(12345)), + ) + + def test_handle_datagram_frame(self): + client = create_standalone_client(self, max_datagram_frame_size=6) + + client._handle_datagram_frame( + client_receive_context(client), + QuicFrameType.DATAGRAM, + Buffer(data=b"hello"), + ) + + self.assertEqual( + client.next_event(), events.DatagramFrameReceived(data=b"hello") + ) + + def test_handle_datagram_frame_not_allowed(self): + client = create_standalone_client(self, max_datagram_frame_size=None) + + with self.assertRaises(QuicConnectionError) as cm: + client._handle_datagram_frame( + client_receive_context(client), + QuicFrameType.DATAGRAM, + Buffer(data=b"hello"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.frame_type, QuicFrameType.DATAGRAM) + self.assertEqual(cm.exception.reason_phrase, "Unexpected DATAGRAM frame") + + def test_handle_datagram_frame_too_large(self): + client = create_standalone_client(self, max_datagram_frame_size=5) + + with self.assertRaises(QuicConnectionError) as cm: + client._handle_datagram_frame( + client_receive_context(client), + QuicFrameType.DATAGRAM, + Buffer(data=b"hello"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.frame_type, QuicFrameType.DATAGRAM) + self.assertEqual(cm.exception.reason_phrase, "Unexpected DATAGRAM frame") + + def test_handle_datagram_frame_with_length(self): + client = create_standalone_client(self, max_datagram_frame_size=7) + + client._handle_datagram_frame( + client_receive_context(client), + QuicFrameType.DATAGRAM_WITH_LENGTH, + Buffer(data=b"\x05hellojunk"), + ) + + self.assertEqual( + client.next_event(), events.DatagramFrameReceived(data=b"hello") 
+ ) + + def test_handle_datagram_frame_with_length_not_allowed(self): + client = create_standalone_client(self, max_datagram_frame_size=None) + + with self.assertRaises(QuicConnectionError) as cm: + client._handle_datagram_frame( + client_receive_context(client), + QuicFrameType.DATAGRAM_WITH_LENGTH, + Buffer(data=b"\x05hellojunk"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.frame_type, QuicFrameType.DATAGRAM_WITH_LENGTH) + self.assertEqual(cm.exception.reason_phrase, "Unexpected DATAGRAM frame") + + def test_handle_datagram_frame_with_length_too_large(self): + client = create_standalone_client(self, max_datagram_frame_size=6) + + with self.assertRaises(QuicConnectionError) as cm: + client._handle_datagram_frame( + client_receive_context(client), + QuicFrameType.DATAGRAM_WITH_LENGTH, + Buffer(data=b"\x05hellojunk"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.frame_type, QuicFrameType.DATAGRAM_WITH_LENGTH) + self.assertEqual(cm.exception.reason_phrase, "Unexpected DATAGRAM frame") + + def test_handle_handshake_done_not_allowed(self): + with client_and_server() as (client, server): + # server receives HANDSHAKE_DONE frame + with self.assertRaises(QuicConnectionError) as cm: + server._handle_handshake_done_frame( + client_receive_context(server), + QuicFrameType.HANDSHAKE_DONE, + Buffer(data=b""), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.frame_type, QuicFrameType.HANDSHAKE_DONE) + self.assertEqual( + cm.exception.reason_phrase, + "Clients must not send HANDSHAKE_DONE frames", + ) + + def test_handle_max_data_frame(self): + with client_and_server() as (client, server): + self.assertEqual(client._remote_max_data, 1048576) + + # client receives MAX_DATA raising limit + client._handle_max_data_frame( + client_receive_context(client), + QuicFrameType.MAX_DATA, + 
Buffer(data=encode_uint_var(1048577)), + ) + self.assertEqual(client._remote_max_data, 1048577) + + def test_handle_max_stream_data_frame(self): + with client_and_server() as (client, server): + # client creates bidirectional stream 0 + stream = client._create_stream(stream_id=0) + self.assertEqual(stream.max_stream_data_remote, 1048576) + + # client receives MAX_STREAM_DATA raising limit + client._handle_max_stream_data_frame( + client_receive_context(client), + QuicFrameType.MAX_STREAM_DATA, + Buffer(data=b"\x00" + encode_uint_var(1048577)), + ) + self.assertEqual(stream.max_stream_data_remote, 1048577) + + # client receives MAX_STREAM_DATA lowering limit + client._handle_max_stream_data_frame( + client_receive_context(client), + QuicFrameType.MAX_STREAM_DATA, + Buffer(data=b"\x00" + encode_uint_var(1048575)), + ) + self.assertEqual(stream.max_stream_data_remote, 1048577) + + def test_handle_max_stream_data_frame_receive_only(self): + with client_and_server() as (client, server): + # server creates unidirectional stream 3 + server.send_stream_data(stream_id=3, data=b"hello") + + # client receives MAX_STREAM_DATA: 3, 1 + with self.assertRaises(QuicConnectionError) as cm: + client._handle_max_stream_data_frame( + client_receive_context(client), + QuicFrameType.MAX_STREAM_DATA, + Buffer(data=b"\x03\x01"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) + self.assertEqual(cm.exception.frame_type, QuicFrameType.MAX_STREAM_DATA) + self.assertEqual(cm.exception.reason_phrase, "Stream is receive-only") + + def test_handle_max_streams_bidi_frame(self): + with client_and_server() as (client, server): + self.assertEqual(client._remote_max_streams_bidi, 128) + + # client receives MAX_STREAMS_BIDI raising limit + client._handle_max_streams_bidi_frame( + client_receive_context(client), + QuicFrameType.MAX_STREAMS_BIDI, + Buffer(data=encode_uint_var(129)), + ) + self.assertEqual(client._remote_max_streams_bidi, 129) + + # client receives 
MAX_STREAMS_BIDI lowering limit + client._handle_max_streams_bidi_frame( + client_receive_context(client), + QuicFrameType.MAX_STREAMS_BIDI, + Buffer(data=encode_uint_var(127)), + ) + self.assertEqual(client._remote_max_streams_bidi, 129) + + def test_handle_max_streams_uni_frame(self): + with client_and_server() as (client, server): + self.assertEqual(client._remote_max_streams_uni, 128) + + # client receives MAX_STREAMS_UNI raising limit + client._handle_max_streams_uni_frame( + client_receive_context(client), + QuicFrameType.MAX_STREAMS_UNI, + Buffer(data=encode_uint_var(129)), + ) + self.assertEqual(client._remote_max_streams_uni, 129) + + # client receives MAX_STREAMS_UNI lowering limit + client._handle_max_streams_uni_frame( + client_receive_context(client), + QuicFrameType.MAX_STREAMS_UNI, + Buffer(data=encode_uint_var(127)), + ) + self.assertEqual(client._remote_max_streams_uni, 129) + + def test_handle_new_token_frame(self): + with client_and_server() as (client, server): + # client receives NEW_TOKEN + client._handle_new_token_frame( + client_receive_context(client), + QuicFrameType.NEW_TOKEN, + Buffer(data=binascii.unhexlify("080102030405060708")), + ) + + def test_handle_new_token_frame_from_client(self): + with client_and_server() as (client, server): + # server receives NEW_TOKEN + with self.assertRaises(QuicConnectionError) as cm: + server._handle_new_token_frame( + client_receive_context(client), + QuicFrameType.NEW_TOKEN, + Buffer(data=binascii.unhexlify("080102030405060708")), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.frame_type, QuicFrameType.NEW_TOKEN) + self.assertEqual( + cm.exception.reason_phrase, "Clients must not send NEW_TOKEN frames" + ) + + def test_handle_path_challenge_frame(self): + with client_and_server() as (client, server): + # client changes address and sends some data + client.send_stream_data(0, b"01234567") + for data, addr in 
client.datagrams_to_send(now=time.time()): + server.receive_datagram(data, ("1.2.3.4", 2345), now=time.time()) + + # check paths + self.assertEqual(len(server._network_paths), 2) + self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 2345)) + self.assertFalse(server._network_paths[0].is_validated) + self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 1234)) + self.assertTrue(server._network_paths[1].is_validated) + + # server sends PATH_CHALLENGE and receives PATH_RESPONSE + for data, addr in server.datagrams_to_send(now=time.time()): + client.receive_datagram(data, SERVER_ADDR, now=time.time()) + for data, addr in client.datagrams_to_send(now=time.time()): + server.receive_datagram(data, ("1.2.3.4", 2345), now=time.time()) + + # check paths + self.assertEqual(server._network_paths[0].addr, ("1.2.3.4", 2345)) + self.assertTrue(server._network_paths[0].is_validated) + self.assertEqual(server._network_paths[1].addr, ("1.2.3.4", 1234)) + self.assertTrue(server._network_paths[1].is_validated) + + def test_handle_path_response_frame_bad(self): + with client_and_server() as (client, server): + # server receives unsolicited PATH_RESPONSE + with self.assertRaises(QuicConnectionError) as cm: + server._handle_path_response_frame( + client_receive_context(client), + QuicFrameType.PATH_RESPONSE, + Buffer(data=b"\x11\x22\x33\x44\x55\x66\x77\x88"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.frame_type, QuicFrameType.PATH_RESPONSE) + + def test_handle_padding_frame(self): + client = create_standalone_client(self) + + # no more padding + buf = Buffer(data=b"") + client._handle_padding_frame( + client_receive_context(client), QuicFrameType.PADDING, buf + ) + self.assertEqual(buf.tell(), 0) + + # padding until end + buf = Buffer(data=bytes(10)) + client._handle_padding_frame( + client_receive_context(client), QuicFrameType.PADDING, buf + ) + self.assertEqual(buf.tell(), 10) + + # padding then 
something else + buf = Buffer(data=bytes(10) + b"\x01") + client._handle_padding_frame( + client_receive_context(client), QuicFrameType.PADDING, buf + ) + self.assertEqual(buf.tell(), 10) + + def test_handle_reset_stream_frame(self): + with client_and_server() as (client, server): + # client creates bidirectional stream 0 + client.send_stream_data(stream_id=0, data=b"hello") + consume_events(client) + + # client receives RESET_STREAM + client._handle_reset_stream_frame( + client_receive_context(client), + QuicFrameType.RESET_STREAM, + Buffer(data=binascii.unhexlify("000100")), + ) + + event = client.next_event() + self.assertEqual(type(event), events.StreamReset) + self.assertEqual(event.error_code, QuicErrorCode.INTERNAL_ERROR) + self.assertEqual(event.stream_id, 0) + + def test_handle_reset_stream_frame_send_only(self): + with client_and_server() as (client, server): + # client creates unidirectional stream 2 + client.send_stream_data(stream_id=2, data=b"hello") + + # client receives RESET_STREAM + with self.assertRaises(QuicConnectionError) as cm: + client._handle_reset_stream_frame( + client_receive_context(client), + QuicFrameType.RESET_STREAM, + Buffer(data=binascii.unhexlify("021100")), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) + self.assertEqual(cm.exception.frame_type, QuicFrameType.RESET_STREAM) + self.assertEqual(cm.exception.reason_phrase, "Stream is send-only") + + def test_handle_retire_connection_id_frame(self): + with client_and_server() as (client, server): + self.assertEqual( + sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] + ) + + # client receives RETIRE_CONNECTION_ID + client._handle_retire_connection_id_frame( + client_receive_context(client), + QuicFrameType.RETIRE_CONNECTION_ID, + Buffer(data=b"\x02"), + ) + self.assertEqual( + sequence_numbers(client._host_cids), [0, 1, 3, 4, 5, 6, 7, 8] + ) + + def test_handle_retire_connection_id_frame_current_cid(self): + with client_and_server() as 
(client, server): + self.assertEqual( + sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] + ) + + # client receives RETIRE_CONNECTION_ID for the current CID + with self.assertRaises(QuicConnectionError) as cm: + client._handle_retire_connection_id_frame( + client_receive_context(client), + QuicFrameType.RETIRE_CONNECTION_ID, + Buffer(data=b"\x00"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual( + cm.exception.frame_type, QuicFrameType.RETIRE_CONNECTION_ID + ) + self.assertEqual( + cm.exception.reason_phrase, "Cannot retire current connection ID" + ) + self.assertEqual( + sequence_numbers(client._host_cids), [0, 1, 2, 3, 4, 5, 6, 7] + ) + + def test_handle_stop_sending_frame(self): + with client_and_server() as (client, server): + # client creates bidirectional stream 0 + client.send_stream_data(stream_id=0, data=b"hello") + + # client receives STOP_SENDING + client._handle_stop_sending_frame( + client_receive_context(client), + QuicFrameType.STOP_SENDING, + Buffer(data=b"\x00\x11"), + ) + + def test_handle_stop_sending_frame_receive_only(self): + with client_and_server() as (client, server): + # server creates unidirectional stream 3 + server.send_stream_data(stream_id=3, data=b"hello") + + # client receives STOP_SENDING + with self.assertRaises(QuicConnectionError) as cm: + client._handle_stop_sending_frame( + client_receive_context(client), + QuicFrameType.STOP_SENDING, + Buffer(data=b"\x03\x11"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) + self.assertEqual(cm.exception.frame_type, QuicFrameType.STOP_SENDING) + self.assertEqual(cm.exception.reason_phrase, "Stream is receive-only") + + def test_handle_stream_frame_over_largest_offset(self): + with client_and_server() as (client, server): + # client receives offset + length > 2^62 - 1 + frame_type = QuicFrameType.STREAM_BASE | 6 + stream_id = 1 + with self.assertRaises(QuicConnectionError) as cm: + 
client._handle_stream_frame( + client_receive_context(client), + frame_type, + Buffer( + data=encode_uint_var(stream_id) + + encode_uint_var(UINT_VAR_MAX) + + encode_uint_var(1) + ), + ) + self.assertEqual( + cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR + ) + self.assertEqual(cm.exception.frame_type, frame_type) + self.assertEqual( + cm.exception.reason_phrase, "offset + length cannot exceed 2^62 - 1" + ) + + def test_handle_stream_frame_over_max_data(self): + with client_and_server() as (client, server): + # artificially raise received data counter + client._local_max_data_used = client._local_max_data + + # client receives STREAM frame + frame_type = QuicFrameType.STREAM_BASE | 4 + stream_id = 1 + with self.assertRaises(QuicConnectionError) as cm: + client._handle_stream_frame( + client_receive_context(client), + frame_type, + Buffer(data=encode_uint_var(stream_id) + encode_uint_var(1)), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.FLOW_CONTROL_ERROR) + self.assertEqual(cm.exception.frame_type, frame_type) + self.assertEqual(cm.exception.reason_phrase, "Over connection data limit") + + def test_handle_stream_frame_over_max_stream_data(self): + with client_and_server() as (client, server): + # client receives STREAM frame + frame_type = QuicFrameType.STREAM_BASE | 4 + stream_id = 1 + with self.assertRaises(QuicConnectionError) as cm: + client._handle_stream_frame( + client_receive_context(client), + frame_type, + Buffer( + data=encode_uint_var(stream_id) + + encode_uint_var(client._local_max_stream_data_bidi_remote + 1) + ), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.FLOW_CONTROL_ERROR) + self.assertEqual(cm.exception.frame_type, frame_type) + self.assertEqual(cm.exception.reason_phrase, "Over stream data limit") + + def test_handle_stream_frame_over_max_streams(self): + with client_and_server() as (client, server): + # client receives STREAM frame + with self.assertRaises(QuicConnectionError) as cm: + 
client._handle_stream_frame( + client_receive_context(client), + QuicFrameType.STREAM_BASE, + Buffer( + data=encode_uint_var(client._local_max_stream_data_uni * 4 + 3) + ), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_LIMIT_ERROR) + self.assertEqual(cm.exception.frame_type, QuicFrameType.STREAM_BASE) + self.assertEqual(cm.exception.reason_phrase, "Too many streams open") + + def test_handle_stream_frame_send_only(self): + with client_and_server() as (client, server): + # client creates unidirectional stream 2 + client.send_stream_data(stream_id=2, data=b"hello") + + # client receives STREAM frame + with self.assertRaises(QuicConnectionError) as cm: + client._handle_stream_frame( + client_receive_context(client), + QuicFrameType.STREAM_BASE, + Buffer(data=b"\x02"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) + self.assertEqual(cm.exception.frame_type, QuicFrameType.STREAM_BASE) + self.assertEqual(cm.exception.reason_phrase, "Stream is send-only") + + def test_handle_stream_frame_wrong_initiator(self): + with client_and_server() as (client, server): + # client receives STREAM frame + with self.assertRaises(QuicConnectionError) as cm: + client._handle_stream_frame( + client_receive_context(client), + QuicFrameType.STREAM_BASE, + Buffer(data=b"\x00"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) + self.assertEqual(cm.exception.frame_type, QuicFrameType.STREAM_BASE) + self.assertEqual(cm.exception.reason_phrase, "Wrong stream initiator") + + def test_handle_stream_data_blocked_frame(self): + with client_and_server() as (client, server): + # client creates bidirectional stream 0 + client.send_stream_data(stream_id=0, data=b"hello") + + # client receives STREAM_DATA_BLOCKED + client._handle_stream_data_blocked_frame( + client_receive_context(client), + QuicFrameType.STREAM_DATA_BLOCKED, + Buffer(data=b"\x00\x01"), + ) + + def 
test_handle_stream_data_blocked_frame_send_only(self): + with client_and_server() as (client, server): + # client creates unidirectional stream 2 + client.send_stream_data(stream_id=2, data=b"hello") + + # client receives STREAM_DATA_BLOCKED + with self.assertRaises(QuicConnectionError) as cm: + client._handle_stream_data_blocked_frame( + client_receive_context(client), + QuicFrameType.STREAM_DATA_BLOCKED, + Buffer(data=b"\x02\x01"), + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.STREAM_STATE_ERROR) + self.assertEqual(cm.exception.frame_type, QuicFrameType.STREAM_DATA_BLOCKED) + self.assertEqual(cm.exception.reason_phrase, "Stream is send-only") + + def test_handle_streams_blocked_uni_frame(self): + with client_and_server() as (client, server): + # client receives STREAMS_BLOCKED_UNI: 0 + client._handle_streams_blocked_frame( + client_receive_context(client), + QuicFrameType.STREAMS_BLOCKED_UNI, + Buffer(data=b"\x00"), + ) + + def test_payload_received_padding_only(self): + with client_and_server() as (client, server): + # client receives padding only + is_ack_eliciting, is_probing = client._payload_received( + client_receive_context(client), b"\x00" * 1200 + ) + self.assertFalse(is_ack_eliciting) + self.assertTrue(is_probing) + + def test_payload_received_unknown_frame(self): + with client_and_server() as (client, server): + # client receives unknown frame + with self.assertRaises(QuicConnectionError) as cm: + client._payload_received(client_receive_context(client), b"\x1f") + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.frame_type, 0x1F) + self.assertEqual(cm.exception.reason_phrase, "Unknown frame type") + + def test_payload_received_unexpected_frame(self): + with client_and_server() as (client, server): + # client receives CRYPTO frame in 0-RTT + with self.assertRaises(QuicConnectionError) as cm: + client._payload_received( + client_receive_context(client, epoch=tls.Epoch.ZERO_RTT), 
b"\x06" + ) + self.assertEqual(cm.exception.error_code, QuicErrorCode.PROTOCOL_VIOLATION) + self.assertEqual(cm.exception.frame_type, QuicFrameType.CRYPTO) + self.assertEqual(cm.exception.reason_phrase, "Unexpected frame type") + + def test_payload_received_malformed_frame(self): + with client_and_server() as (client, server): + # client receives malformed TRANSPORT_CLOSE frame + with self.assertRaises(QuicConnectionError) as cm: + client._payload_received( + client_receive_context(client), b"\x1c\x00\x01" + ) + self.assertEqual( + cm.exception.error_code, QuicErrorCode.FRAME_ENCODING_ERROR + ) + self.assertEqual(cm.exception.frame_type, 0x1C) + self.assertEqual(cm.exception.reason_phrase, "Failed to parse frame") + + def test_send_max_data_blocked_by_cc(self): + with client_and_server() as (client, server): + # check congestion control + self.assertEqual(client._loss.bytes_in_flight, 0) + self.assertEqual(client._loss.congestion_window, 14303) + + # artificially raise received data counter + client._local_max_data_used = client._local_max_data + self.assertEqual(server._remote_max_data, 1048576) + + # artificially raise bytes in flight + client._loss._cc.bytes_in_flight = 14303 + + # MAX_DATA is not sent due to congestion control + self.assertEqual(drop(client), 0) + + def test_send_max_data_retransmit(self): + with client_and_server() as (client, server): + # artificially raise received data counter + client._local_max_data_used = client._local_max_data + self.assertEqual(server._remote_max_data, 1048576) + + # MAX_DATA is sent and lost + self.assertEqual(drop(client), 1) + self.assertEqual(client._local_max_data_sent, 2097152) + self.assertEqual(server._remote_max_data, 1048576) + + # MAX_DATA is retransmitted and acked + client._on_max_data_delivery(QuicDeliveryState.LOST) + self.assertEqual(client._local_max_data_sent, 0) + self.assertEqual(roundtrip(client, server), (1, 1)) + self.assertEqual(server._remote_max_data, 2097152) + + def 
test_send_max_stream_data_retransmit(self): + with client_and_server() as (client, server): + # client creates bidirectional stream 0 + stream = client._create_stream(stream_id=0) + client.send_stream_data(0, b"hello") + self.assertEqual(stream.max_stream_data_local, 1048576) + self.assertEqual(stream.max_stream_data_local_sent, 1048576) + self.assertEqual(roundtrip(client, server), (1, 1)) + + # server sends data, just before raising MAX_STREAM_DATA + server.send_stream_data(0, b"Z" * 524288) # 1048576 // 2 + for i in range(10): + roundtrip(server, client) + self.assertEqual(stream.max_stream_data_local, 1048576) + self.assertEqual(stream.max_stream_data_local_sent, 1048576) + + # server sends one more byte + server.send_stream_data(0, b"Z") + self.assertEqual(transfer(server, client), 1) + + # MAX_STREAM_DATA is sent and lost + self.assertEqual(drop(client), 1) + self.assertEqual(stream.max_stream_data_local, 2097152) + self.assertEqual(stream.max_stream_data_local_sent, 2097152) + client._on_max_stream_data_delivery(QuicDeliveryState.LOST, stream) + self.assertEqual(stream.max_stream_data_local, 2097152) + self.assertEqual(stream.max_stream_data_local_sent, 0) + + # MAX_DATA is retransmitted and acked + self.assertEqual(roundtrip(client, server), (1, 1)) + self.assertEqual(stream.max_stream_data_local, 2097152) + self.assertEqual(stream.max_stream_data_local_sent, 2097152) + + def test_send_ping(self): + with client_and_server() as (client, server): + consume_events(client) + + # client sends ping, server ACKs it + client.send_ping(uid=12345) + self.assertEqual(roundtrip(client, server), (1, 1)) + + # check event + event = client.next_event() + self.assertEqual(type(event), events.PingAcknowledged) + self.assertEqual(event.uid, 12345) + + def test_send_ping_retransmit(self): + with client_and_server() as (client, server): + consume_events(client) + + # client sends another ping, PING is lost + client.send_ping(uid=12345) + self.assertEqual(drop(client), 1) + + # 
PING is retransmitted and acked + client._on_ping_delivery(QuicDeliveryState.LOST, (12345,)) + self.assertEqual(roundtrip(client, server), (1, 1)) + + # check event + event = client.next_event() + self.assertEqual(type(event), events.PingAcknowledged) + self.assertEqual(event.uid, 12345) + + def test_send_stream_data_over_max_streams_bidi(self): + with client_and_server() as (client, server): + # create streams + for i in range(128): + stream_id = i * 4 + client.send_stream_data(stream_id, b"") + self.assertFalse(client._streams[stream_id].is_blocked) + self.assertEqual(len(client._streams_blocked_bidi), 0) + self.assertEqual(len(client._streams_blocked_uni), 0) + self.assertEqual(roundtrip(client, server), (0, 0)) + + # create one too many -> STREAMS_BLOCKED + stream_id = 128 * 4 + client.send_stream_data(stream_id, b"") + self.assertTrue(client._streams[stream_id].is_blocked) + self.assertEqual(len(client._streams_blocked_bidi), 1) + self.assertEqual(len(client._streams_blocked_uni), 0) + self.assertEqual(roundtrip(client, server), (1, 1)) + + # peer raises max streams + client._handle_max_streams_bidi_frame( + client_receive_context(client), + QuicFrameType.MAX_STREAMS_BIDI, + Buffer(data=encode_uint_var(129)), + ) + self.assertFalse(client._streams[stream_id].is_blocked) + + def test_send_stream_data_over_max_streams_uni(self): + with client_and_server() as (client, server): + # create streams + for i in range(128): + stream_id = i * 4 + 2 + client.send_stream_data(stream_id, b"") + self.assertFalse(client._streams[stream_id].is_blocked) + self.assertEqual(len(client._streams_blocked_bidi), 0) + self.assertEqual(len(client._streams_blocked_uni), 0) + self.assertEqual(roundtrip(client, server), (0, 0)) + + # create one too many -> STREAMS_BLOCKED + stream_id = 128 * 4 + 2 + client.send_stream_data(stream_id, b"") + self.assertTrue(client._streams[stream_id].is_blocked) + self.assertEqual(len(client._streams_blocked_bidi), 0) + 
self.assertEqual(len(client._streams_blocked_uni), 1) + self.assertEqual(roundtrip(client, server), (1, 1)) + + # peer raises max streams + client._handle_max_streams_uni_frame( + client_receive_context(client), + QuicFrameType.MAX_STREAMS_UNI, + Buffer(data=encode_uint_var(129)), + ) + self.assertFalse(client._streams[stream_id].is_blocked) + + def test_send_stream_data_peer_initiated(self): + with client_and_server() as (client, server): + # server creates bidirectional stream + server.send_stream_data(1, b"hello") + self.assertEqual(roundtrip(server, client), (1, 1)) + + # server creates unidirectional stream + server.send_stream_data(3, b"hello") + self.assertEqual(roundtrip(server, client), (1, 1)) + + # client creates bidirectional stream + client.send_stream_data(0, b"hello") + self.assertEqual(roundtrip(client, server), (1, 1)) + + # client sends data on server-initiated bidirectional stream + client.send_stream_data(1, b"hello") + self.assertEqual(roundtrip(client, server), (1, 1)) + + # client creates unidirectional stream + client.send_stream_data(2, b"hello") + self.assertEqual(roundtrip(client, server), (1, 1)) + + # client tries to send data on server-initiated unidirectional stream + with self.assertRaises(ValueError) as cm: + client.send_stream_data(3, b"hello") + self.assertEqual( + str(cm.exception), + "Cannot send data on peer-initiated unidirectional stream", + ) + + # client tries to send data on unknown server-initiated bidirectional stream + with self.assertRaises(ValueError) as cm: + client.send_stream_data(5, b"hello") + self.assertEqual( + str(cm.exception), "Cannot send data on unknown peer-initiated stream" + ) + + def test_stream_direction(self): + with client_and_server() as (client, server): + for off in [0, 4, 8]: + # Client-Initiated, Bidirectional + self.assertTrue(client._stream_can_receive(off)) + self.assertTrue(client._stream_can_send(off)) + self.assertTrue(server._stream_can_receive(off)) + 
self.assertTrue(server._stream_can_send(off)) + + # Server-Initiated, Bidirectional + self.assertTrue(client._stream_can_receive(off + 1)) + self.assertTrue(client._stream_can_send(off + 1)) + self.assertTrue(server._stream_can_receive(off + 1)) + self.assertTrue(server._stream_can_send(off + 1)) + + # Client-Initiated, Unidirectional + self.assertFalse(client._stream_can_receive(off + 2)) + self.assertTrue(client._stream_can_send(off + 2)) + self.assertTrue(server._stream_can_receive(off + 2)) + self.assertFalse(server._stream_can_send(off + 2)) + + # Server-Initiated, Unidirectional + self.assertTrue(client._stream_can_receive(off + 3)) + self.assertFalse(client._stream_can_send(off + 3)) + self.assertFalse(server._stream_can_receive(off + 3)) + self.assertTrue(server._stream_can_send(off + 3)) + + def test_version_negotiation_fail(self): + client = create_standalone_client(self) + + # no common version, no retry + client.receive_datagram( + encode_quic_version_negotiation( + source_cid=client._peer_cid, + destination_cid=client.host_cid, + supported_versions=[0xFF000011], # DRAFT_16 + ), + SERVER_ADDR, + now=time.time(), + ) + self.assertEqual(drop(client), 0) + + event = client.next_event() + self.assertEqual(type(event), events.ConnectionTerminated) + self.assertEqual(event.error_code, QuicErrorCode.INTERNAL_ERROR) + self.assertEqual(event.frame_type, None) + self.assertEqual( + event.reason_phrase, "Could not find a common protocol version" + ) + + def test_version_negotiation_ok(self): + client = create_standalone_client(self) + + # found a common version, retry + client.receive_datagram( + encode_quic_version_negotiation( + source_cid=client._peer_cid, + destination_cid=client.host_cid, + supported_versions=[client._version], + ), + SERVER_ADDR, + now=time.time(), + ) + self.assertEqual(drop(client), 1) + + +class QuicNetworkPathTest(TestCase): + def test_can_send(self): + path = QuicNetworkPath(("1.2.3.4", 1234)) + self.assertFalse(path.is_validated) + + # 
initially, cannot send any data + self.assertTrue(path.can_send(0)) + self.assertFalse(path.can_send(1)) + + # receive some data + path.bytes_received += 1 + self.assertTrue(path.can_send(0)) + self.assertTrue(path.can_send(1)) + self.assertTrue(path.can_send(2)) + self.assertTrue(path.can_send(3)) + self.assertFalse(path.can_send(4)) + + # send some data + path.bytes_sent += 3 + self.assertTrue(path.can_send(0)) + self.assertFalse(path.can_send(1)) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_crypto.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_crypto.py new file mode 100644 index 000000000000..424e4dcf5e18 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_crypto.py @@ -0,0 +1,319 @@ +import binascii +from unittest import TestCase, skipIf + +from aioquic.buffer import Buffer +from aioquic.quic.crypto import ( + INITIAL_CIPHER_SUITE, + CryptoError, + CryptoPair, + derive_key_iv_hp, +) +from aioquic.quic.packet import PACKET_FIXED_BIT, QuicProtocolVersion +from aioquic.tls import CipherSuite + +from .utils import SKIP_TESTS + +PROTOCOL_VERSION = QuicProtocolVersion.DRAFT_25 + +CHACHA20_CLIENT_PACKET_NUMBER = 2 +CHACHA20_CLIENT_PLAIN_HEADER = binascii.unhexlify( + "e1ff0000160880b57c7b70d8524b0850fc2a28e240fd7640170002" +) +CHACHA20_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify("0201000000") +CHACHA20_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify( + "e8ff0000160880b57c7b70d8524b0850fc2a28e240fd7640178313b04be98449" + "eb10567e25ce930381f2a5b7da2db8db" +) + +LONG_CLIENT_PACKET_NUMBER = 2 +LONG_CLIENT_PLAIN_HEADER = binascii.unhexlify( + "c3ff000017088394c8f03e5157080000449e00000002" +) +LONG_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify( + "060040c4010000c003036660261ff947cea49cce6cfad687f457cf1b14531ba1" + "4131a0e8f309a1d0b9c4000006130113031302010000910000000b0009000006" + "736572766572ff01000100000a00140012001d00170018001901000101010201" + 
"03010400230000003300260024001d00204cfdfcd178b784bf328cae793b136f" + "2aedce005ff183d7bb1495207236647037002b0003020304000d0020001e0403" + "05030603020308040805080604010501060102010402050206020202002d0002" + "0101001c00024001" +) + bytes(962) +LONG_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify( + "c0ff000017088394c8f03e5157080000449e3b343aa8535064a4268a0d9d7b1c" + "9d250ae355162276e9b1e3011ef6bbc0ab48ad5bcc2681e953857ca62becd752" + "4daac473e68d7405fbba4e9ee616c87038bdbe908c06d9605d9ac49030359eec" + "b1d05a14e117db8cede2bb09d0dbbfee271cb374d8f10abec82d0f59a1dee29f" + "e95638ed8dd41da07487468791b719c55c46968eb3b54680037102a28e53dc1d" + "12903db0af5821794b41c4a93357fa59ce69cfe7f6bdfa629eef78616447e1d6" + "11c4baf71bf33febcb03137c2c75d25317d3e13b684370f668411c0f00304b50" + "1c8fd422bd9b9ad81d643b20da89ca0525d24d2b142041cae0af205092e43008" + "0cd8559ea4c5c6e4fa3f66082b7d303e52ce0162baa958532b0bbc2bc785681f" + "cf37485dff6595e01e739c8ac9efba31b985d5f656cc092432d781db95221724" + "87641c4d3ab8ece01e39bc85b15436614775a98ba8fa12d46f9b35e2a55eb72d" + "7f85181a366663387ddc20551807e007673bd7e26bf9b29b5ab10a1ca87cbb7a" + "d97e99eb66959c2a9bc3cbde4707ff7720b110fa95354674e395812e47a0ae53" + "b464dcb2d1f345df360dc227270c750676f6724eb479f0d2fbb6124429990457" + "ac6c9167f40aab739998f38b9eccb24fd47c8410131bf65a52af841275d5b3d1" + "880b197df2b5dea3e6de56ebce3ffb6e9277a82082f8d9677a6767089b671ebd" + "244c214f0bde95c2beb02cd1172d58bdf39dce56ff68eb35ab39b49b4eac7c81" + "5ea60451d6e6ab82119118df02a586844a9ffe162ba006d0669ef57668cab38b" + "62f71a2523a084852cd1d079b3658dc2f3e87949b550bab3e177cfc49ed190df" + "f0630e43077c30de8f6ae081537f1e83da537da980afa668e7b7fb25301cf741" + "524be3c49884b42821f17552fbd1931a813017b6b6590a41ea18b6ba49cd48a4" + "40bd9a3346a7623fb4ba34a3ee571e3c731f35a7a3cf25b551a680fa68763507" + "b7fde3aaf023c50b9d22da6876ba337eb5e9dd9ec3daf970242b6c5aab3aa4b2" + "96ad8b9f6832f686ef70fa938b31b4e5ddd7364442d3ea72e73d668fb0937796" + 
"f462923a81a47e1cee7426ff6d9221269b5a62ec03d6ec94d12606cb485560ba" + "b574816009e96504249385bb61a819be04f62c2066214d8360a2022beb316240" + "b6c7d78bbe56c13082e0ca272661210abf020bf3b5783f1426436cf9ff418405" + "93a5d0638d32fc51c5c65ff291a3a7a52fd6775e623a4439cc08dd25582febc9" + "44ef92d8dbd329c91de3e9c9582e41f17f3d186f104ad3f90995116c682a2a14" + "a3b4b1f547c335f0be710fc9fc03e0e587b8cda31ce65b969878a4ad4283e6d5" + "b0373f43da86e9e0ffe1ae0fddd3516255bd74566f36a38703d5f34249ded1f6" + "6b3d9b45b9af2ccfefe984e13376b1b2c6404aa48c8026132343da3f3a33659e" + "c1b3e95080540b28b7f3fcd35fa5d843b579a84c089121a60d8c1754915c344e" + "eaf45a9bf27dc0c1e78416169122091313eb0e87555abd706626e557fc36a04f" + "cd191a58829104d6075c5594f627ca506bf181daec940f4a4f3af0074eee89da" + "acde6758312622d4fa675b39f728e062d2bee680d8f41a597c262648bb18bcfc" + "13c8b3d97b1a77b2ac3af745d61a34cc4709865bac824a94bb19058015e4e42d" + "c9be6c7803567321829dd85853396269" +) + +LONG_SERVER_PACKET_NUMBER = 1 +LONG_SERVER_PLAIN_HEADER = binascii.unhexlify( + "c1ff0000170008f067a5502a4262b50040740001" +) +LONG_SERVER_PLAIN_PAYLOAD = binascii.unhexlify( + "0d0000000018410a020000560303eefce7f7b37ba1d1632e96677825ddf73988" + "cfc79825df566dc5430b9a045a1200130100002e00330024001d00209d3c940d" + "89690b84d08a60993c144eca684d1081287c834d5311bcf32bb9da1a002b0002" + "0304" +) +LONG_SERVER_ENCRYPTED_PACKET = binascii.unhexlify( + "c9ff0000170008f067a5502a4262b5004074168bf22b7002596f99ae67abf65a" + "5852f54f58c37c808682e2e40492d8a3899fb04fc0afe9aabc8767b18a0aa493" + "537426373b48d502214dd856d63b78cee37bc664b3fe86d487ac7a77c53038a3" + "cd32f0b5004d9f5754c4f7f2d1f35cf3f7116351c92b9cf9bb6d091ddfc8b32d" + "432348a2c413" +) + +SHORT_SERVER_PACKET_NUMBER = 3 +SHORT_SERVER_PLAIN_HEADER = binascii.unhexlify("41b01fd24a586a9cf30003") +SHORT_SERVER_PLAIN_PAYLOAD = binascii.unhexlify( + "06003904000035000151805a4bebf5000020b098c8dc4183e4c182572e10ac3e" + "2b88897e0524c8461847548bd2dffa2c0ae60008002a0004ffffffff" +) 
+SHORT_SERVER_ENCRYPTED_PACKET = binascii.unhexlify( + "5db01fd24a586a9cf33dec094aaec6d6b4b7a5e15f5a3f05d06cf1ad0355c19d" + "cce0807eecf7bf1c844a66e1ecd1f74b2a2d69bfd25d217833edd973246597bd" + "5107ea15cb1e210045396afa602fe23432f4ab24ce251b" +) + + +class CryptoTest(TestCase): + """ + Test vectors from: + + https://tools.ietf.org/html/draft-ietf-quic-tls-18#appendix-A + """ + + def create_crypto(self, is_client): + pair = CryptoPair() + pair.setup_initial( + cid=binascii.unhexlify("8394c8f03e515708"), + is_client=is_client, + version=PROTOCOL_VERSION, + ) + return pair + + def test_derive_key_iv_hp(self): + # client + secret = binascii.unhexlify( + "8a3515a14ae3c31b9c2d6d5bc58538ca5cd2baa119087143e60887428dcb52f6" + ) + key, iv, hp = derive_key_iv_hp(INITIAL_CIPHER_SUITE, secret) + self.assertEqual(key, binascii.unhexlify("98b0d7e5e7a402c67c33f350fa65ea54")) + self.assertEqual(iv, binascii.unhexlify("19e94387805eb0b46c03a788")) + self.assertEqual(hp, binascii.unhexlify("0edd982a6ac527f2eddcbb7348dea5d7")) + + # server + secret = binascii.unhexlify( + "47b2eaea6c266e32c0697a9e2a898bdf5c4fb3e5ac34f0e549bf2c58581a3811" + ) + key, iv, hp = derive_key_iv_hp(INITIAL_CIPHER_SUITE, secret) + self.assertEqual(key, binascii.unhexlify("9a8be902a9bdd91d16064ca118045fb4")) + self.assertEqual(iv, binascii.unhexlify("0a82086d32205ba22241d8dc")) + self.assertEqual(hp, binascii.unhexlify("94b9452d2b3c7c7f6da7fdd8593537fd")) + + @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") + def test_decrypt_chacha20(self): + pair = CryptoPair() + pair.recv.setup( + cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, + secret=binascii.unhexlify( + "b42772df33c9719a32820d302aa664d080d7f5ea7a71a330f87864cb289ae8c0" + ), + version=PROTOCOL_VERSION, + ) + + plain_header, plain_payload, packet_number = pair.decrypt_packet( + CHACHA20_CLIENT_ENCRYPTED_PACKET, 25, 0 + ) + self.assertEqual(plain_header, CHACHA20_CLIENT_PLAIN_HEADER) + self.assertEqual(plain_payload, 
CHACHA20_CLIENT_PLAIN_PAYLOAD) + self.assertEqual(packet_number, CHACHA20_CLIENT_PACKET_NUMBER) + + def test_decrypt_long_client(self): + pair = self.create_crypto(is_client=False) + + plain_header, plain_payload, packet_number = pair.decrypt_packet( + LONG_CLIENT_ENCRYPTED_PACKET, 18, 0 + ) + self.assertEqual(plain_header, LONG_CLIENT_PLAIN_HEADER) + self.assertEqual(plain_payload, LONG_CLIENT_PLAIN_PAYLOAD) + self.assertEqual(packet_number, LONG_CLIENT_PACKET_NUMBER) + + def test_decrypt_long_server(self): + pair = self.create_crypto(is_client=True) + + plain_header, plain_payload, packet_number = pair.decrypt_packet( + LONG_SERVER_ENCRYPTED_PACKET, 18, 0 + ) + self.assertEqual(plain_header, LONG_SERVER_PLAIN_HEADER) + self.assertEqual(plain_payload, LONG_SERVER_PLAIN_PAYLOAD) + self.assertEqual(packet_number, LONG_SERVER_PACKET_NUMBER) + + def test_decrypt_no_key(self): + pair = CryptoPair() + with self.assertRaises(CryptoError): + pair.decrypt_packet(LONG_SERVER_ENCRYPTED_PACKET, 18, 0) + + def test_decrypt_short_server(self): + pair = CryptoPair() + pair.recv.setup( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=binascii.unhexlify( + "310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100" + ), + version=PROTOCOL_VERSION, + ) + + plain_header, plain_payload, packet_number = pair.decrypt_packet( + SHORT_SERVER_ENCRYPTED_PACKET, 9, 0 + ) + self.assertEqual(plain_header, SHORT_SERVER_PLAIN_HEADER) + self.assertEqual(plain_payload, SHORT_SERVER_PLAIN_PAYLOAD) + self.assertEqual(packet_number, SHORT_SERVER_PACKET_NUMBER) + + @skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests") + def test_encrypt_chacha20(self): + pair = CryptoPair() + pair.send.setup( + cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256, + secret=binascii.unhexlify( + "b42772df33c9719a32820d302aa664d080d7f5ea7a71a330f87864cb289ae8c0" + ), + version=PROTOCOL_VERSION, + ) + + packet = pair.encrypt_packet( + CHACHA20_CLIENT_PLAIN_HEADER, + CHACHA20_CLIENT_PLAIN_PAYLOAD, + 
CHACHA20_CLIENT_PACKET_NUMBER, + ) + self.assertEqual(packet, CHACHA20_CLIENT_ENCRYPTED_PACKET) + + def test_encrypt_long_client(self): + pair = self.create_crypto(is_client=True) + + packet = pair.encrypt_packet( + LONG_CLIENT_PLAIN_HEADER, + LONG_CLIENT_PLAIN_PAYLOAD, + LONG_CLIENT_PACKET_NUMBER, + ) + self.assertEqual(packet, LONG_CLIENT_ENCRYPTED_PACKET) + + def test_encrypt_long_server(self): + pair = self.create_crypto(is_client=False) + + packet = pair.encrypt_packet( + LONG_SERVER_PLAIN_HEADER, + LONG_SERVER_PLAIN_PAYLOAD, + LONG_SERVER_PACKET_NUMBER, + ) + self.assertEqual(packet, LONG_SERVER_ENCRYPTED_PACKET) + + def test_encrypt_short_server(self): + pair = CryptoPair() + pair.send.setup( + cipher_suite=INITIAL_CIPHER_SUITE, + secret=binascii.unhexlify( + "310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100" + ), + version=PROTOCOL_VERSION, + ) + + packet = pair.encrypt_packet( + SHORT_SERVER_PLAIN_HEADER, + SHORT_SERVER_PLAIN_PAYLOAD, + SHORT_SERVER_PACKET_NUMBER, + ) + self.assertEqual(packet, SHORT_SERVER_ENCRYPTED_PACKET) + + def test_key_update(self): + pair1 = self.create_crypto(is_client=True) + pair2 = self.create_crypto(is_client=False) + + def create_packet(key_phase, packet_number): + buf = Buffer(capacity=100) + buf.push_uint8(PACKET_FIXED_BIT | key_phase << 2 | 1) + buf.push_bytes(binascii.unhexlify("8394c8f03e515708")) + buf.push_uint16(packet_number) + return buf.data, b"\x00\x01\x02\x03" + + def send(sender, receiver, packet_number=0): + plain_header, plain_payload = create_packet( + key_phase=sender.key_phase, packet_number=packet_number + ) + encrypted = sender.encrypt_packet( + plain_header, plain_payload, packet_number + ) + recov_header, recov_payload, recov_packet_number = receiver.decrypt_packet( + encrypted, len(plain_header) - 2, 0 + ) + self.assertEqual(recov_header, plain_header) + self.assertEqual(recov_payload, plain_payload) + self.assertEqual(recov_packet_number, packet_number) + + # roundtrip + send(pair1, 
pair2, 0) + send(pair2, pair1, 0) + self.assertEqual(pair1.key_phase, 0) + self.assertEqual(pair2.key_phase, 0) + + # pair 1 key update + pair1.update_key() + + # roundtrip + send(pair1, pair2, 1) + send(pair2, pair1, 1) + self.assertEqual(pair1.key_phase, 1) + self.assertEqual(pair2.key_phase, 1) + + # pair 2 key update + pair2.update_key() + + # roundtrip + send(pair2, pair1, 2) + send(pair1, pair2, 2) + self.assertEqual(pair1.key_phase, 0) + self.assertEqual(pair2.key_phase, 0) + + # pair 1 key - update, but not next to send + pair1.update_key() + + # roundtrip + send(pair2, pair1, 3) + send(pair1, pair2, 3) + self.assertEqual(pair1.key_phase, 1) + self.assertEqual(pair2.key_phase, 1) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_h0.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_h0.py new file mode 100644 index 000000000000..81ccf65e1174 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_h0.py @@ -0,0 +1,148 @@ +from unittest import TestCase + +from aioquic.h0.connection import H0_ALPN, H0Connection +from aioquic.h3.events import DataReceived, HeadersReceived + +from .test_connection import client_and_server, transfer + + +def h0_client_and_server(): + return client_and_server( + client_options={"alpn_protocols": H0_ALPN}, + server_options={"alpn_protocols": H0_ALPN}, + ) + + +def h0_transfer(quic_sender, h0_receiver): + quic_receiver = h0_receiver._quic + transfer(quic_sender, quic_receiver) + + # process QUIC events + http_events = [] + event = quic_receiver.next_event() + while event is not None: + http_events.extend(h0_receiver.handle_event(event)) + event = quic_receiver.next_event() + return http_events + + +class H0ConnectionTest(TestCase): + def test_connect(self): + with h0_client_and_server() as (quic_client, quic_server): + h0_client = H0Connection(quic_client) + h0_server = H0Connection(quic_server) + + # send request + stream_id = 
quic_client.get_next_available_stream_id() + h0_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + ) + h0_client.send_data(stream_id=stream_id, data=b"", end_stream=True) + + # receive request + events = h0_transfer(quic_client, h0_server) + self.assertEqual(len(events), 2) + + self.assertTrue(isinstance(events[0], HeadersReceived)) + self.assertEqual( + events[0].headers, [(b":method", b"GET"), (b":path", b"/")] + ) + self.assertEqual(events[0].stream_id, stream_id) + self.assertEqual(events[0].stream_ended, False) + + self.assertTrue(isinstance(events[1], DataReceived)) + self.assertEqual(events[1].data, b"") + self.assertEqual(events[1].stream_id, stream_id) + self.assertEqual(events[1].stream_ended, True) + + # send response + h0_server.send_headers( + stream_id=stream_id, + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + ], + ) + h0_server.send_data( + stream_id=stream_id, + data=b"hello", + end_stream=True, + ) + + # receive response + events = h0_transfer(quic_server, h0_client) + self.assertEqual(len(events), 2) + + self.assertTrue(isinstance(events[0], HeadersReceived)) + self.assertEqual(events[0].headers, []) + self.assertEqual(events[0].stream_id, stream_id) + self.assertEqual(events[0].stream_ended, False) + + self.assertTrue(isinstance(events[1], DataReceived)) + self.assertEqual(events[1].data, b"hello") + self.assertEqual(events[1].stream_id, stream_id) + self.assertEqual(events[1].stream_ended, True) + + def test_headers_only(self): + with h0_client_and_server() as (quic_client, quic_server): + h0_client = H0Connection(quic_client) + h0_server = H0Connection(quic_server) + + # send request + stream_id = quic_client.get_next_available_stream_id() + h0_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"HEAD"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + 
(b":path", b"/"), + ], + end_stream=True, + ) + + # receive request + events = h0_transfer(quic_client, h0_server) + self.assertEqual(len(events), 2) + + self.assertTrue(isinstance(events[0], HeadersReceived)) + self.assertEqual( + events[0].headers, [(b":method", b"HEAD"), (b":path", b"/")] + ) + self.assertEqual(events[0].stream_id, stream_id) + self.assertEqual(events[0].stream_ended, False) + + self.assertTrue(isinstance(events[1], DataReceived)) + self.assertEqual(events[1].data, b"") + self.assertEqual(events[1].stream_id, stream_id) + self.assertEqual(events[1].stream_ended, True) + + # send response + h0_server.send_headers( + stream_id=stream_id, + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + ], + end_stream=True, + ) + + # receive response + events = h0_transfer(quic_server, h0_client) + self.assertEqual(len(events), 2) + + self.assertTrue(isinstance(events[0], HeadersReceived)) + self.assertEqual(events[0].headers, []) + self.assertEqual(events[0].stream_id, stream_id) + self.assertEqual(events[0].stream_ended, False) + + self.assertTrue(isinstance(events[1], DataReceived)) + self.assertEqual(events[1].data, b"") + self.assertEqual(events[1].stream_id, stream_id) + self.assertEqual(events[1].stream_ended, True) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_h3.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_h3.py new file mode 100644 index 000000000000..1e550399b4a3 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_h3.py @@ -0,0 +1,1292 @@ +import binascii +from unittest import TestCase + +from aioquic.buffer import encode_uint_var +from aioquic.h3.connection import ( + H3_ALPN, + ErrorCode, + FrameType, + FrameUnexpected, + H3Connection, + StreamType, + encode_frame, +) +from aioquic.h3.events import DataReceived, HeadersReceived, PushPromiseReceived +from aioquic.h3.exceptions import NoAvailablePushIDError +from 
aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.events import StreamDataReceived +from aioquic.quic.logger import QuicLogger + +from .test_connection import client_and_server, transfer + + +def h3_client_and_server(): + return client_and_server( + client_options={"alpn_protocols": H3_ALPN}, + server_options={"alpn_protocols": H3_ALPN}, + ) + + +def h3_transfer(quic_sender, h3_receiver): + quic_receiver = h3_receiver._quic + if hasattr(quic_sender, "stream_queue"): + quic_receiver._events.extend(quic_sender.stream_queue) + quic_sender.stream_queue.clear() + else: + transfer(quic_sender, quic_receiver) + + # process QUIC events + http_events = [] + event = quic_receiver.next_event() + while event is not None: + http_events.extend(h3_receiver.handle_event(event)) + event = quic_receiver.next_event() + return http_events + + +class FakeQuicConnection: + def __init__(self, configuration): + self.closed = None + self.configuration = configuration + self.stream_queue = [] + self._events = [] + self._next_stream_bidi = 0 if configuration.is_client else 1 + self._next_stream_uni = 2 if configuration.is_client else 3 + self._quic_logger = QuicLogger().start_trace( + is_client=configuration.is_client, odcid=b"" + ) + + def close(self, error_code, reason_phrase): + self.closed = (error_code, reason_phrase) + + def get_next_available_stream_id(self, is_unidirectional=False): + if is_unidirectional: + stream_id = self._next_stream_uni + self._next_stream_uni += 4 + else: + stream_id = self._next_stream_bidi + self._next_stream_bidi += 4 + return stream_id + + def next_event(self): + try: + return self._events.pop(0) + except IndexError: + return None + + def send_stream_data(self, stream_id, data, end_stream=False): + # chop up data into individual bytes + for c in data: + self.stream_queue.append( + StreamDataReceived( + data=bytes([c]), end_stream=False, stream_id=stream_id + ) + ) + if end_stream: + self.stream_queue.append( + 
StreamDataReceived(data=b"", end_stream=end_stream, stream_id=stream_id) + ) + + +class H3ConnectionTest(TestCase): + maxDiff = None + + def _make_request(self, h3_client, h3_server): + quic_client = h3_client._quic + quic_server = h3_server._quic + + # send request + stream_id = quic_client.get_next_available_stream_id() + h3_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + (b"x-foo", b"client"), + ], + ) + h3_client.send_data(stream_id=stream_id, data=b"", end_stream=True) + + # receive request + events = h3_transfer(quic_client, h3_server) + self.assertEqual( + events, + [ + HeadersReceived( + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + (b"x-foo", b"client"), + ], + stream_id=stream_id, + stream_ended=False, + ), + DataReceived(data=b"", stream_id=stream_id, stream_ended=True), + ], + ) + + # send response + h3_server.send_headers( + stream_id=stream_id, + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + (b"x-foo", b"server"), + ], + ) + h3_server.send_data( + stream_id=stream_id, + data=b"hello", + end_stream=True, + ) + + # receive response + events = h3_transfer(quic_server, h3_client) + self.assertEqual( + events, + [ + HeadersReceived( + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + (b"x-foo", b"server"), + ], + stream_id=stream_id, + stream_ended=False, + ), + DataReceived( + data=b"hello", + stream_id=stream_id, + stream_ended=True, + ), + ], + ) + + def test_handle_control_frame_headers(self): + """ + We should not receive HEADERS on the control stream. 
+ """ + quic_server = FakeQuicConnection( + configuration=QuicConfiguration(is_client=False) + ) + h3_server = H3Connection(quic_server) + + h3_server.handle_event( + StreamDataReceived( + stream_id=2, + data=encode_uint_var(StreamType.CONTROL) + + encode_frame(FrameType.HEADERS, b""), + end_stream=False, + ) + ) + self.assertEqual( + quic_server.closed, + (ErrorCode.HTTP_FRAME_UNEXPECTED, "Invalid frame type on control stream"), + ) + + def test_handle_control_frame_max_push_id_from_server(self): + """ + A client should not receive MAX_PUSH_ID on the control stream. + """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + h3_client.handle_event( + StreamDataReceived( + stream_id=3, + data=encode_uint_var(StreamType.CONTROL) + + encode_frame(FrameType.MAX_PUSH_ID, b""), + end_stream=False, + ) + ) + self.assertEqual( + quic_client.closed, + (ErrorCode.HTTP_FRAME_UNEXPECTED, "Servers must not send MAX_PUSH_ID"), + ) + + def test_handle_control_stream_duplicate(self): + """ + We must only receive a single control stream. + """ + quic_server = FakeQuicConnection( + configuration=QuicConfiguration(is_client=False) + ) + h3_server = H3Connection(quic_server) + + # receive a first control stream + h3_server.handle_event( + StreamDataReceived( + stream_id=2, data=encode_uint_var(StreamType.CONTROL), end_stream=False + ) + ) + + # receive a second control stream + h3_server.handle_event( + StreamDataReceived( + stream_id=6, data=encode_uint_var(StreamType.CONTROL), end_stream=False + ) + ) + self.assertEqual( + quic_server.closed, + ( + ErrorCode.HTTP_STREAM_CREATION_ERROR, + "Only one control stream is allowed", + ), + ) + + def test_handle_push_frame_wrong_frame_type(self): + """ + We should not received SETTINGS on a push stream. 
+ """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + h3_client.handle_event( + StreamDataReceived( + stream_id=15, + data=encode_uint_var(StreamType.PUSH) + + encode_uint_var(0) # push ID + + encode_frame(FrameType.SETTINGS, b""), + end_stream=False, + ) + ) + self.assertEqual( + quic_client.closed, + (ErrorCode.HTTP_FRAME_UNEXPECTED, "Invalid frame type on push stream"), + ) + + def test_handle_qpack_decoder_duplicate(self): + """ + We must only receive a single QPACK decoder stream. + """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + # receive a first decoder stream + h3_client.handle_event( + StreamDataReceived( + stream_id=11, + data=encode_uint_var(StreamType.QPACK_DECODER), + end_stream=False, + ) + ) + + # receive a second decoder stream + h3_client.handle_event( + StreamDataReceived( + stream_id=15, + data=encode_uint_var(StreamType.QPACK_DECODER), + end_stream=False, + ) + ) + self.assertEqual( + quic_client.closed, + ( + ErrorCode.HTTP_STREAM_CREATION_ERROR, + "Only one QPACK decoder stream is allowed", + ), + ) + + def test_handle_qpack_decoder_stream_error(self): + """ + Receiving garbage on the QPACK decoder stream triggers an exception. + """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + h3_client.handle_event( + StreamDataReceived( + stream_id=11, + data=encode_uint_var(StreamType.QPACK_DECODER) + b"\x00", + end_stream=False, + ) + ) + self.assertEqual( + quic_client.closed, (ErrorCode.HTTP_QPACK_DECODER_STREAM_ERROR, "") + ) + + def test_handle_qpack_encoder_duplicate(self): + """ + We must only receive a single QPACK encoder stream. 
+ """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + # receive a first encoder stream + h3_client.handle_event( + StreamDataReceived( + stream_id=11, + data=encode_uint_var(StreamType.QPACK_ENCODER), + end_stream=False, + ) + ) + + # receive a second encoder stream + h3_client.handle_event( + StreamDataReceived( + stream_id=15, + data=encode_uint_var(StreamType.QPACK_ENCODER), + end_stream=False, + ) + ) + self.assertEqual( + quic_client.closed, + ( + ErrorCode.HTTP_STREAM_CREATION_ERROR, + "Only one QPACK encoder stream is allowed", + ), + ) + + def test_handle_qpack_encoder_stream_error(self): + """ + Receiving garbage on the QPACK encoder stream triggers an exception. + """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + h3_client.handle_event( + StreamDataReceived( + stream_id=7, + data=encode_uint_var(StreamType.QPACK_ENCODER) + b"\x00", + end_stream=False, + ) + ) + self.assertEqual( + quic_client.closed, (ErrorCode.HTTP_QPACK_ENCODER_STREAM_ERROR, "") + ) + + def test_handle_request_frame_bad_headers(self): + """ + We should not receive HEADERS which cannot be decoded. + """ + quic_server = FakeQuicConnection( + configuration=QuicConfiguration(is_client=False) + ) + h3_server = H3Connection(quic_server) + + h3_server.handle_event( + StreamDataReceived( + stream_id=0, data=encode_frame(FrameType.HEADERS, b""), end_stream=False + ) + ) + self.assertEqual( + quic_server.closed, (ErrorCode.HTTP_QPACK_DECOMPRESSION_FAILED, "") + ) + + def test_handle_request_frame_data_before_headers(self): + """ + We should not receive DATA before receiving headers. 
+ """ + quic_server = FakeQuicConnection( + configuration=QuicConfiguration(is_client=False) + ) + h3_server = H3Connection(quic_server) + + h3_server.handle_event( + StreamDataReceived( + stream_id=0, data=encode_frame(FrameType.DATA, b""), end_stream=False + ) + ) + self.assertEqual( + quic_server.closed, + ( + ErrorCode.HTTP_FRAME_UNEXPECTED, + "DATA frame is not allowed in this state", + ), + ) + + def test_handle_request_frame_headers_after_trailers(self): + """ + We should not receive HEADERS after receiving trailers. + """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + quic_server = FakeQuicConnection( + configuration=QuicConfiguration(is_client=False) + ) + + h3_client = H3Connection(quic_client) + h3_server = H3Connection(quic_server) + + stream_id = quic_client.get_next_available_stream_id() + h3_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + ) + h3_client.send_headers( + stream_id=stream_id, headers=[(b"x-some-trailer", b"foo")], end_stream=True + ) + h3_transfer(quic_client, h3_server) + + h3_server.handle_event( + StreamDataReceived( + stream_id=0, data=encode_frame(FrameType.HEADERS, b""), end_stream=False + ) + ) + self.assertEqual( + quic_server.closed, + ( + ErrorCode.HTTP_FRAME_UNEXPECTED, + "HEADERS frame is not allowed in this state", + ), + ) + + def test_handle_request_frame_push_promise_from_client(self): + """ + A server should not receive PUSH_PROMISE on a request stream. 
+ """ + quic_server = FakeQuicConnection( + configuration=QuicConfiguration(is_client=False) + ) + h3_server = H3Connection(quic_server) + + h3_server.handle_event( + StreamDataReceived( + stream_id=0, + data=encode_frame(FrameType.PUSH_PROMISE, b""), + end_stream=False, + ) + ) + self.assertEqual( + quic_server.closed, + (ErrorCode.HTTP_FRAME_UNEXPECTED, "Clients must not send PUSH_PROMISE"), + ) + + def test_handle_request_frame_wrong_frame_type(self): + quic_server = FakeQuicConnection( + configuration=QuicConfiguration(is_client=False) + ) + h3_server = H3Connection(quic_server) + + h3_server.handle_event( + StreamDataReceived( + stream_id=0, + data=encode_frame(FrameType.SETTINGS, b""), + end_stream=False, + ) + ) + self.assertEqual( + quic_server.closed, + (ErrorCode.HTTP_FRAME_UNEXPECTED, "Invalid frame type on request stream"), + ) + + def test_request(self): + with h3_client_and_server() as (quic_client, quic_server): + h3_client = H3Connection(quic_client) + h3_server = H3Connection(quic_server) + + # make first request + self._make_request(h3_client, h3_server) + + # make second request + self._make_request(h3_client, h3_server) + + # make third request -> dynamic table + self._make_request(h3_client, h3_server) + + def test_request_headers_only(self): + with h3_client_and_server() as (quic_client, quic_server): + h3_client = H3Connection(quic_client) + h3_server = H3Connection(quic_server) + + # send request + stream_id = quic_client.get_next_available_stream_id() + h3_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"HEAD"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + (b"x-foo", b"client"), + ], + end_stream=True, + ) + + # receive request + events = h3_transfer(quic_client, h3_server) + self.assertEqual( + events, + [ + HeadersReceived( + headers=[ + (b":method", b"HEAD"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + (b"x-foo", b"client"), + ], + 
stream_id=stream_id, + stream_ended=True, + ) + ], + ) + + # send response + h3_server.send_headers( + stream_id=stream_id, + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + (b"x-foo", b"server"), + ], + end_stream=True, + ) + + # receive response + events = h3_transfer(quic_server, h3_client) + self.assertEqual( + events, + [ + HeadersReceived( + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + (b"x-foo", b"server"), + ], + stream_id=stream_id, + stream_ended=True, + ) + ], + ) + + def test_request_fragmented_frame(self): + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + quic_server = FakeQuicConnection( + configuration=QuicConfiguration(is_client=False) + ) + + h3_client = H3Connection(quic_client) + h3_server = H3Connection(quic_server) + + # send request + stream_id = quic_client.get_next_available_stream_id() + h3_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + (b"x-foo", b"client"), + ], + ) + h3_client.send_data(stream_id=stream_id, data=b"hello", end_stream=True) + + # receive request + events = h3_transfer(quic_client, h3_server) + self.assertEqual( + events, + [ + HeadersReceived( + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + (b"x-foo", b"client"), + ], + stream_id=stream_id, + stream_ended=False, + ), + DataReceived(data=b"h", stream_id=0, stream_ended=False), + DataReceived(data=b"e", stream_id=0, stream_ended=False), + DataReceived(data=b"l", stream_id=0, stream_ended=False), + DataReceived(data=b"l", stream_id=0, stream_ended=False), + DataReceived(data=b"o", stream_id=0, stream_ended=False), + DataReceived(data=b"", stream_id=0, stream_ended=True), + ], + ) + + # send push promise + push_stream_id = h3_server.send_push_promise( + stream_id=stream_id, + 
headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/app.txt"), + ], + ) + self.assertEqual(push_stream_id, 15) + + # send response + h3_server.send_headers( + stream_id=stream_id, + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + ], + end_stream=False, + ) + h3_server.send_data(stream_id=stream_id, data=b"html", end_stream=True) + + #  fulfill push promise + h3_server.send_headers( + stream_id=push_stream_id, + headers=[(b":status", b"200"), (b"content-type", b"text/plain")], + end_stream=False, + ) + h3_server.send_data(stream_id=push_stream_id, data=b"text", end_stream=True) + + # receive push promise / reponse + events = h3_transfer(quic_server, h3_client) + self.assertEqual( + events, + [ + PushPromiseReceived( + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/app.txt"), + ], + push_id=0, + stream_id=stream_id, + ), + HeadersReceived( + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + ], + stream_id=0, + stream_ended=False, + ), + DataReceived(data=b"h", stream_id=0, stream_ended=False), + DataReceived(data=b"t", stream_id=0, stream_ended=False), + DataReceived(data=b"m", stream_id=0, stream_ended=False), + DataReceived(data=b"l", stream_id=0, stream_ended=False), + DataReceived(data=b"", stream_id=0, stream_ended=True), + HeadersReceived( + headers=[(b":status", b"200"), (b"content-type", b"text/plain")], + stream_id=15, + stream_ended=False, + push_id=0, + ), + DataReceived(data=b"t", stream_id=15, stream_ended=False, push_id=0), + DataReceived(data=b"e", stream_id=15, stream_ended=False, push_id=0), + DataReceived(data=b"x", stream_id=15, stream_ended=False, push_id=0), + DataReceived(data=b"t", stream_id=15, stream_ended=False, push_id=0), + DataReceived(data=b"", stream_id=15, stream_ended=True, push_id=0), + ], + ) + + def test_request_with_server_push(self): + with 
h3_client_and_server() as (quic_client, quic_server): + h3_client = H3Connection(quic_client) + h3_server = H3Connection(quic_server) + + # send request + stream_id = quic_client.get_next_available_stream_id() + h3_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + end_stream=True, + ) + + # receive request + events = h3_transfer(quic_client, h3_server) + self.assertEqual( + events, + [ + HeadersReceived( + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + stream_id=stream_id, + stream_ended=True, + ) + ], + ) + + # send push promises + push_stream_id_css = h3_server.send_push_promise( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/app.css"), + ], + ) + self.assertEqual(push_stream_id_css, 15) + + push_stream_id_js = h3_server.send_push_promise( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/app.js"), + ], + ) + self.assertEqual(push_stream_id_js, 19) + + # send response + h3_server.send_headers( + stream_id=stream_id, + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + ], + end_stream=False, + ) + h3_server.send_data( + stream_id=stream_id, + data=b"hello", + end_stream=True, + ) + + #  fulfill push promises + h3_server.send_headers( + stream_id=push_stream_id_css, + headers=[(b":status", b"200"), (b"content-type", b"text/css")], + end_stream=False, + ) + h3_server.send_data( + stream_id=push_stream_id_css, + data=b"body { color: pink }", + end_stream=True, + ) + + h3_server.send_headers( + stream_id=push_stream_id_js, + headers=[ + (b":status", b"200"), + (b"content-type", b"application/javascript"), + ], + end_stream=False, + ) + h3_server.send_data( + 
stream_id=push_stream_id_js, data=b"alert('howdee');", end_stream=True + ) + + # receive push promises, response and push responses + + events = h3_transfer(quic_server, h3_client) + self.assertEqual( + events, + [ + PushPromiseReceived( + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/app.css"), + ], + push_id=0, + stream_id=stream_id, + ), + PushPromiseReceived( + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/app.js"), + ], + push_id=1, + stream_id=stream_id, + ), + HeadersReceived( + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + ], + stream_id=stream_id, + stream_ended=False, + ), + DataReceived( + data=b"hello", + stream_id=stream_id, + stream_ended=True, + ), + HeadersReceived( + headers=[(b":status", b"200"), (b"content-type", b"text/css")], + push_id=0, + stream_id=push_stream_id_css, + stream_ended=False, + ), + DataReceived( + data=b"body { color: pink }", + push_id=0, + stream_id=push_stream_id_css, + stream_ended=True, + ), + HeadersReceived( + headers=[ + (b":status", b"200"), + (b"content-type", b"application/javascript"), + ], + push_id=1, + stream_id=push_stream_id_js, + stream_ended=False, + ), + DataReceived( + data=b"alert('howdee');", + push_id=1, + stream_id=push_stream_id_js, + stream_ended=True, + ), + ], + ) + + def test_request_with_server_push_max_push_id(self): + with h3_client_and_server() as (quic_client, quic_server): + h3_client = H3Connection(quic_client) + h3_server = H3Connection(quic_server) + + # send request + stream_id = quic_client.get_next_available_stream_id() + h3_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + end_stream=True, + ) + + # receive request + events = h3_transfer(quic_client, h3_server) + self.assertEqual( + events, + [ + 
HeadersReceived( + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + stream_id=stream_id, + stream_ended=True, + ) + ], + ) + + # send push promises + for i in range(0, 8): + h3_server.send_push_promise( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", "/{}.css".format(i).encode("ascii")), + ], + ) + + # send one too many + with self.assertRaises(NoAvailablePushIDError): + h3_server.send_push_promise( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/8.css"), + ], + ) + + def test_send_data_after_trailers(self): + """ + We should not send DATA after trailers. + """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + stream_id = quic_client.get_next_available_stream_id() + h3_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + ) + h3_client.send_headers( + stream_id=stream_id, headers=[(b"x-some-trailer", b"foo")], end_stream=False + ) + with self.assertRaises(FrameUnexpected): + h3_client.send_data(stream_id=stream_id, data=b"hello", end_stream=False) + + def test_send_data_before_headers(self): + """ + We should not send DATA before headers. + """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + stream_id = quic_client.get_next_available_stream_id() + with self.assertRaises(FrameUnexpected): + h3_client.send_data(stream_id=stream_id, data=b"hello", end_stream=False) + + def test_send_headers_after_trailers(self): + """ + We should not send HEADERS after trailers. 
+ """ + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + stream_id = quic_client.get_next_available_stream_id() + h3_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + ) + h3_client.send_headers( + stream_id=stream_id, headers=[(b"x-some-trailer", b"foo")], end_stream=False + ) + with self.assertRaises(FrameUnexpected): + h3_client.send_headers( + stream_id=stream_id, + headers=[(b"x-other-trailer", b"foo")], + end_stream=False, + ) + + def test_blocked_stream(self): + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + h3_client.handle_event( + StreamDataReceived( + stream_id=3, + data=binascii.unhexlify( + "0004170150000680020000074064091040bcc0000000faceb00c" + ), + end_stream=False, + ) + ) + h3_client.handle_event( + StreamDataReceived(stream_id=7, data=b"\x02", end_stream=False) + ) + h3_client.handle_event( + StreamDataReceived(stream_id=11, data=b"\x03", end_stream=False) + ) + h3_client.handle_event( + StreamDataReceived( + stream_id=0, data=binascii.unhexlify("01040280d910"), end_stream=False + ) + ) + h3_client.handle_event( + StreamDataReceived( + stream_id=0, + data=binascii.unhexlify( + "00408d796f752072656163686564206d766673742e6e65742c20726561636820" + "746865202f6563686f20656e64706f696e7420666f7220616e206563686f2072" + "6573706f6e7365207175657279202f3c6e756d6265723e20656e64706f696e74" + "7320666f722061207661726961626c652073697a6520726573706f6e73652077" + "6974682072616e646f6d206279746573" + ), + end_stream=True, + ) + ) + self.assertEqual( + h3_client.handle_event( + StreamDataReceived( + stream_id=7, + data=binascii.unhexlify( + "3fe101c696d07abe941094cb6d0a08017d403971966e32ca98b46f" + ), + end_stream=False, + ) + ), + [ + HeadersReceived( + headers=[ + (b":status", 
b"200"), + (b"date", b"Mon, 22 Jul 2019 06:33:33 GMT"), + ], + stream_id=0, + stream_ended=False, + ), + DataReceived( + data=( + b"you reached mvfst.net, reach the /echo endpoint for an " + b"echo response query / endpoints for a variable " + b"size response with random bytes" + ), + stream_id=0, + stream_ended=True, + ), + ], + ) + + def test_blocked_stream_trailer(self): + quic_client = FakeQuicConnection( + configuration=QuicConfiguration(is_client=True) + ) + h3_client = H3Connection(quic_client) + + h3_client.handle_event( + StreamDataReceived( + stream_id=3, + data=binascii.unhexlify( + "0004170150000680020000074064091040bcc0000000faceb00c" + ), + end_stream=False, + ) + ) + h3_client.handle_event( + StreamDataReceived(stream_id=7, data=b"\x02", end_stream=False) + ) + h3_client.handle_event( + StreamDataReceived(stream_id=11, data=b"\x03", end_stream=False) + ) + + self.assertEqual( + h3_client.handle_event( + StreamDataReceived( + stream_id=0, + data=binascii.unhexlify( + "011b0000d95696d07abe941094cb6d0a08017d403971966e32ca98b46f" + ), + end_stream=False, + ) + ), + [ + HeadersReceived( + headers=[ + (b":status", b"200"), + (b"date", b"Mon, 22 Jul 2019 06:33:33 GMT"), + ], + stream_id=0, + stream_ended=False, + ) + ], + ) + + self.assertEqual( + h3_client.handle_event( + StreamDataReceived( + stream_id=0, + data=binascii.unhexlify( + "00408d796f752072656163686564206d766673742e6e65742c20726561636820" + "746865202f6563686f20656e64706f696e7420666f7220616e206563686f2072" + "6573706f6e7365207175657279202f3c6e756d6265723e20656e64706f696e74" + "7320666f722061207661726961626c652073697a6520726573706f6e73652077" + "6974682072616e646f6d206279746573" + ), + end_stream=False, + ) + ), + [ + DataReceived( + data=( + b"you reached mvfst.net, reach the /echo endpoint for an " + b"echo response query / endpoints for a variable " + b"size response with random bytes" + ), + stream_id=0, + stream_ended=False, + ) + ], + ) + + self.assertEqual( + h3_client.handle_event( + 
StreamDataReceived( + stream_id=0, data=binascii.unhexlify("0103028010"), end_stream=True + ) + ), + [], + ) + + self.assertEqual( + h3_client.handle_event( + StreamDataReceived( + stream_id=7, + data=binascii.unhexlify("6af2b20f49564d833505b38294e7"), + end_stream=False, + ) + ), + [ + HeadersReceived( + headers=[(b"x-some-trailer", b"foo")], + stream_id=0, + stream_ended=True, + push_id=None, + ) + ], + ) + + def test_uni_stream_grease(self): + with h3_client_and_server() as (quic_client, quic_server): + h3_server = H3Connection(quic_server) + + quic_client.send_stream_data( + 14, b"\xff\xff\xff\xff\xff\xff\xff\xfeGREASE is the word" + ) + self.assertEqual(h3_transfer(quic_client, h3_server), []) + + def test_request_with_trailers(self): + with h3_client_and_server() as (quic_client, quic_server): + h3_client = H3Connection(quic_client) + h3_server = H3Connection(quic_server) + + # send request with trailers + stream_id = quic_client.get_next_available_stream_id() + h3_client.send_headers( + stream_id=stream_id, + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + end_stream=False, + ) + h3_client.send_headers( + stream_id=stream_id, + headers=[(b"x-some-trailer", b"foo")], + end_stream=True, + ) + + # receive request + events = h3_transfer(quic_client, h3_server) + self.assertEqual( + events, + [ + HeadersReceived( + headers=[ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", b"/"), + ], + stream_id=stream_id, + stream_ended=False, + ), + HeadersReceived( + headers=[(b"x-some-trailer", b"foo")], + stream_id=stream_id, + stream_ended=True, + ), + ], + ) + + # send response + h3_server.send_headers( + stream_id=stream_id, + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + ], + end_stream=False, + ) + h3_server.send_data( + stream_id=stream_id, + data=b"hello", + end_stream=False, + ) + h3_server.send_headers( + 
stream_id=stream_id, + headers=[(b"x-some-trailer", b"bar")], + end_stream=True, + ) + + # receive response + events = h3_transfer(quic_server, h3_client) + self.assertEqual( + events, + [ + HeadersReceived( + headers=[ + (b":status", b"200"), + (b"content-type", b"text/html; charset=utf-8"), + ], + stream_id=stream_id, + stream_ended=False, + ), + DataReceived( + data=b"hello", + stream_id=stream_id, + stream_ended=False, + ), + HeadersReceived( + headers=[(b"x-some-trailer", b"bar")], + stream_id=stream_id, + stream_ended=True, + ), + ], + ) + + def test_uni_stream_type(self): + with h3_client_and_server() as (quic_client, quic_server): + h3_server = H3Connection(quic_server) + + # unknown stream type 9 + stream_id = quic_client.get_next_available_stream_id(is_unidirectional=True) + self.assertEqual(stream_id, 2) + quic_client.send_stream_data(stream_id, b"\x09") + self.assertEqual(h3_transfer(quic_client, h3_server), []) + self.assertEqual(list(h3_server._stream.keys()), [2]) + self.assertEqual(h3_server._stream[2].buffer, b"") + self.assertEqual(h3_server._stream[2].stream_type, 9) + + # unknown stream type 64, one byte at a time + stream_id = quic_client.get_next_available_stream_id(is_unidirectional=True) + self.assertEqual(stream_id, 6) + + quic_client.send_stream_data(stream_id, b"\x40") + self.assertEqual(h3_transfer(quic_client, h3_server), []) + self.assertEqual(list(h3_server._stream.keys()), [2, 6]) + self.assertEqual(h3_server._stream[2].buffer, b"") + self.assertEqual(h3_server._stream[2].stream_type, 9) + self.assertEqual(h3_server._stream[6].buffer, b"\x40") + self.assertEqual(h3_server._stream[6].stream_type, None) + + quic_client.send_stream_data(stream_id, b"\x40") + self.assertEqual(h3_transfer(quic_client, h3_server), []) + self.assertEqual(list(h3_server._stream.keys()), [2, 6]) + self.assertEqual(h3_server._stream[2].buffer, b"") + self.assertEqual(h3_server._stream[2].stream_type, 9) + self.assertEqual(h3_server._stream[6].buffer, b"") + 
self.assertEqual(h3_server._stream[6].stream_type, 64) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_logger.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_logger.py new file mode 100644 index 000000000000..bb7c520938e2 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_logger.py @@ -0,0 +1,37 @@ +from unittest import TestCase + +from aioquic.quic.logger import QuicLogger + + +class QuicLoggerTest(TestCase): + def test_empty(self): + logger = QuicLogger() + self.assertEqual(logger.to_dict(), {"qlog_version": "draft-01", "traces": []}) + + def test_empty_trace(self): + logger = QuicLogger() + trace = logger.start_trace(is_client=True, odcid=bytes(8)) + logger.end_trace(trace) + self.assertEqual( + logger.to_dict(), + { + "qlog_version": "draft-01", + "traces": [ + { + "common_fields": { + "ODCID": "0000000000000000", + "reference_time": "0", + }, + "configuration": {"time_units": "us"}, + "event_fields": [ + "relative_time", + "category", + "event_type", + "data", + ], + "events": [], + "vantage_point": {"name": "aioquic", "type": "client"}, + } + ], + }, + ) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_packet.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_packet.py new file mode 100644 index 000000000000..269b7aa9a619 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_packet.py @@ -0,0 +1,513 @@ +import binascii +from unittest import TestCase + +from aioquic.buffer import Buffer, BufferReadError +from aioquic.quic import packet +from aioquic.quic.packet import ( + PACKET_TYPE_INITIAL, + PACKET_TYPE_RETRY, + QuicPreferredAddress, + QuicProtocolVersion, + QuicTransportParameters, + decode_packet_number, + encode_quic_version_negotiation, + get_retry_integrity_tag, + pull_quic_header, + pull_quic_preferred_address, + pull_quic_transport_parameters, + push_quic_preferred_address, + 
push_quic_transport_parameters, +) + +from .utils import load + + +class PacketTest(TestCase): + def test_decode_packet_number(self): + # expected = 0 + for i in range(0, 256): + self.assertEqual(decode_packet_number(i, 8, expected=0), i) + + # expected = 128 + self.assertEqual(decode_packet_number(0, 8, expected=128), 256) + for i in range(1, 256): + self.assertEqual(decode_packet_number(i, 8, expected=128), i) + + # expected = 129 + self.assertEqual(decode_packet_number(0, 8, expected=129), 256) + self.assertEqual(decode_packet_number(1, 8, expected=129), 257) + for i in range(2, 256): + self.assertEqual(decode_packet_number(i, 8, expected=129), i) + + # expected = 256 + for i in range(0, 128): + self.assertEqual(decode_packet_number(i, 8, expected=256), 256 + i) + for i in range(129, 256): + self.assertEqual(decode_packet_number(i, 8, expected=256), i) + + def test_pull_empty(self): + buf = Buffer(data=b"") + with self.assertRaises(BufferReadError): + pull_quic_header(buf, host_cid_length=8) + + def test_pull_initial_client(self): + buf = Buffer(data=load("initial_client.bin")) + header = pull_quic_header(buf, host_cid_length=8) + self.assertTrue(header.is_long_header) + self.assertEqual(header.version, QuicProtocolVersion.DRAFT_25) + self.assertEqual(header.packet_type, PACKET_TYPE_INITIAL) + self.assertEqual(header.destination_cid, binascii.unhexlify("858b39368b8e3c6e")) + self.assertEqual(header.source_cid, b"") + self.assertEqual(header.token, b"") + self.assertEqual(header.integrity_tag, b"") + self.assertEqual(header.rest_length, 1262) + self.assertEqual(buf.tell(), 18) + + def test_pull_initial_server(self): + buf = Buffer(data=load("initial_server.bin")) + header = pull_quic_header(buf, host_cid_length=8) + self.assertTrue(header.is_long_header) + self.assertEqual(header.version, QuicProtocolVersion.DRAFT_25) + self.assertEqual(header.packet_type, PACKET_TYPE_INITIAL) + self.assertEqual(header.destination_cid, b"") + self.assertEqual(header.source_cid, 
binascii.unhexlify("195c68344e28d479")) + self.assertEqual(header.token, b"") + self.assertEqual(header.integrity_tag, b"") + self.assertEqual(header.rest_length, 184) + self.assertEqual(buf.tell(), 18) + + def test_pull_retry(self): + buf = Buffer(data=load("retry.bin")) + header = pull_quic_header(buf, host_cid_length=8) + self.assertTrue(header.is_long_header) + self.assertEqual(header.version, QuicProtocolVersion.DRAFT_25) + self.assertEqual(header.packet_type, PACKET_TYPE_RETRY) + self.assertEqual(header.destination_cid, binascii.unhexlify("e9d146d8d14cb28e")) + self.assertEqual( + header.source_cid, + binascii.unhexlify("0b0a205a648fcf82d85f128b67bbe08053e6"), + ) + self.assertEqual( + header.token, + binascii.unhexlify( + "44397a35d698393c134b08a932737859f446d3aadd00ed81540c8d8de172" + "906d3e7a111b503f9729b8928e7528f9a86a4581f9ebb4cb3b53c283661e" + "8530741a99192ee56914c5626998ec0f" + ), + ) + self.assertEqual( + header.integrity_tag, binascii.unhexlify("e1a3c80c797ea401c08fc9c342a2d90d") + ) + self.assertEqual(header.rest_length, 0) + self.assertEqual(buf.tell(), 125) + + # check integrity + self.assertEqual( + get_retry_integrity_tag( + buf.data_slice(0, 109), binascii.unhexlify("fbbd219b7363b64b"), + ), + header.integrity_tag, + ) + + def test_pull_version_negotiation(self): + buf = Buffer(data=load("version_negotiation.bin")) + header = pull_quic_header(buf, host_cid_length=8) + self.assertTrue(header.is_long_header) + self.assertEqual(header.version, QuicProtocolVersion.NEGOTIATION) + self.assertEqual(header.packet_type, None) + self.assertEqual(header.destination_cid, binascii.unhexlify("9aac5a49ba87a849")) + self.assertEqual(header.source_cid, binascii.unhexlify("f92f4336fa951ba1")) + self.assertEqual(header.token, b"") + self.assertEqual(header.integrity_tag, b"") + self.assertEqual(header.rest_length, 8) + self.assertEqual(buf.tell(), 23) + + def test_pull_long_header_dcid_too_long(self): + buf = Buffer( + data=binascii.unhexlify( + 
"c6ff0000161500000000000000000000000000000000000000000000004" + "01c514f99ec4bbf1f7a30f9b0c94fef717f1c1d07fec24c99a864da7ede" + ) + ) + with self.assertRaises(ValueError) as cm: + pull_quic_header(buf, host_cid_length=8) + self.assertEqual(str(cm.exception), "Destination CID is too long (21 bytes)") + + def test_pull_long_header_scid_too_long(self): + buf = Buffer( + data=binascii.unhexlify( + "c2ff0000160015000000000000000000000000000000000000000000004" + "01cfcee99ec4bbf1f7a30f9b0c9417b8c263cdd8cc972a4439d68a46320" + ) + ) + with self.assertRaises(ValueError) as cm: + pull_quic_header(buf, host_cid_length=8) + self.assertEqual(str(cm.exception), "Source CID is too long (21 bytes)") + + def test_pull_long_header_no_fixed_bit(self): + buf = Buffer(data=b"\x80\xff\x00\x00\x11\x00\x00") + with self.assertRaises(ValueError) as cm: + pull_quic_header(buf, host_cid_length=8) + self.assertEqual(str(cm.exception), "Packet fixed bit is zero") + + def test_pull_long_header_too_short(self): + buf = Buffer(data=b"\xc0\x00") + with self.assertRaises(BufferReadError): + pull_quic_header(buf, host_cid_length=8) + + def test_pull_short_header(self): + buf = Buffer(data=load("short_header.bin")) + header = pull_quic_header(buf, host_cid_length=8) + self.assertFalse(header.is_long_header) + self.assertEqual(header.version, None) + self.assertEqual(header.packet_type, 0x50) + self.assertEqual(header.destination_cid, binascii.unhexlify("f45aa7b59c0e1ad6")) + self.assertEqual(header.source_cid, b"") + self.assertEqual(header.token, b"") + self.assertEqual(header.integrity_tag, b"") + self.assertEqual(header.rest_length, 12) + self.assertEqual(buf.tell(), 9) + + def test_pull_short_header_no_fixed_bit(self): + buf = Buffer(data=b"\x00") + with self.assertRaises(ValueError) as cm: + pull_quic_header(buf, host_cid_length=8) + self.assertEqual(str(cm.exception), "Packet fixed bit is zero") + + def test_encode_quic_version_negotiation(self): + data = encode_quic_version_negotiation( + 
destination_cid=binascii.unhexlify("9aac5a49ba87a849"), + source_cid=binascii.unhexlify("f92f4336fa951ba1"), + supported_versions=[0x45474716, QuicProtocolVersion.DRAFT_25], + ) + self.assertEqual(data[1:], load("version_negotiation.bin")[1:]) + + +class ParamsTest(TestCase): + maxDiff = None + + def test_params(self): + data = binascii.unhexlify( + "010267100210cc2fd6e7d97a53ab5be85b28d75c8008030247e404048005fff" + "a05048000ffff06048000ffff0801060a01030b0119" + ) + + # parse + buf = Buffer(data=data) + params = pull_quic_transport_parameters( + buf, protocol_version=QuicProtocolVersion.DRAFT_27 + ) + self.assertEqual( + params, + QuicTransportParameters( + idle_timeout=10000, + stateless_reset_token=b"\xcc/\xd6\xe7\xd9zS\xab[\xe8[(\xd7\\\x80\x08", + max_packet_size=2020, + initial_max_data=393210, + initial_max_stream_data_bidi_local=65535, + initial_max_stream_data_bidi_remote=65535, + initial_max_stream_data_uni=None, + initial_max_streams_bidi=6, + initial_max_streams_uni=None, + ack_delay_exponent=3, + max_ack_delay=25, + ), + ) + + # serialize + buf = Buffer(capacity=len(data)) + push_quic_transport_parameters( + buf, params, protocol_version=QuicProtocolVersion.DRAFT_27 + ) + self.assertEqual(len(buf.data), len(data)) + + def test_params_legacy(self): + data = binascii.unhexlify( + "004700020010cc2fd6e7d97a53ab5be85b28d75c80080008000106000100026" + "710000600048000ffff000500048000ffff000400048005fffa000a00010300" + "0b0001190003000247e4" + ) + + # parse + buf = Buffer(data=data) + params = pull_quic_transport_parameters( + buf, protocol_version=QuicProtocolVersion.DRAFT_25 + ) + self.assertEqual( + params, + QuicTransportParameters( + idle_timeout=10000, + stateless_reset_token=b"\xcc/\xd6\xe7\xd9zS\xab[\xe8[(\xd7\\\x80\x08", + max_packet_size=2020, + initial_max_data=393210, + initial_max_stream_data_bidi_local=65535, + initial_max_stream_data_bidi_remote=65535, + initial_max_stream_data_uni=None, + initial_max_streams_bidi=6, + 
initial_max_streams_uni=None, + ack_delay_exponent=3, + max_ack_delay=25, + ), + ) + + # serialize + buf = Buffer(capacity=len(data)) + push_quic_transport_parameters( + buf, params, protocol_version=QuicProtocolVersion.DRAFT_25 + ) + self.assertEqual(len(buf.data), len(data)) + + def test_params_disable_active_migration(self): + data = binascii.unhexlify("0c00") + + # parse + buf = Buffer(data=data) + params = pull_quic_transport_parameters( + buf, protocol_version=QuicProtocolVersion.DRAFT_27 + ) + self.assertEqual(params, QuicTransportParameters(disable_active_migration=True)) + + # serialize + buf = Buffer(capacity=len(data)) + push_quic_transport_parameters( + buf, params, protocol_version=QuicProtocolVersion.DRAFT_27 + ) + self.assertEqual(buf.data, data) + + def test_params_disable_active_migration_legacy(self): + data = binascii.unhexlify("0004000c0000") + + # parse + buf = Buffer(data=data) + params = pull_quic_transport_parameters( + buf, protocol_version=QuicProtocolVersion.DRAFT_25 + ) + self.assertEqual(params, QuicTransportParameters(disable_active_migration=True)) + + # serialize + buf = Buffer(capacity=len(data)) + push_quic_transport_parameters( + buf, params, protocol_version=QuicProtocolVersion.DRAFT_25 + ) + self.assertEqual(buf.data, data) + + def test_params_preferred_address(self): + data = binascii.unhexlify( + "0d3b8ba27b8611532400890200000000f03c91fffe69a45411531262c4518d6" + "3013f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48ecb4" + ) + + # parse + buf = Buffer(data=data) + params = pull_quic_transport_parameters( + buf, protocol_version=QuicProtocolVersion.DRAFT_27 + ) + self.assertEqual( + params, + QuicTransportParameters( + preferred_address=QuicPreferredAddress( + ipv4_address=("139.162.123.134", 4435), + ipv6_address=("2400:8902::f03c:91ff:fe69:a454", 4435), + connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7", + stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4", + ), + ), + ) + + # serialize + buf = 
Buffer(capacity=1000) + push_quic_transport_parameters( + buf, params, protocol_version=QuicProtocolVersion.DRAFT_27 + ) + self.assertEqual(buf.data, data) + + def test_params_preferred_address_legacy(self): + data = binascii.unhexlify( + "003f000d003b8ba27b8611532400890200000000f03c91fffe69a4541153126" + "2c4518d63013f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48" + "ecb4" + ) + + # parse + buf = Buffer(data=data) + params = pull_quic_transport_parameters( + buf, protocol_version=QuicProtocolVersion.DRAFT_25 + ) + self.assertEqual( + params, + QuicTransportParameters( + preferred_address=QuicPreferredAddress( + ipv4_address=("139.162.123.134", 4435), + ipv6_address=("2400:8902::f03c:91ff:fe69:a454", 4435), + connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7", + stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4", + ), + ), + ) + + # serialize + buf = Buffer(capacity=len(data)) + push_quic_transport_parameters( + buf, params, protocol_version=QuicProtocolVersion.DRAFT_25 + ) + self.assertEqual(buf.data, data) + + def test_params_unknown(self): + data = binascii.unhexlify("8000ff000100") + + # parse + buf = Buffer(data=data) + params = pull_quic_transport_parameters( + buf, protocol_version=QuicProtocolVersion.DRAFT_27 + ) + self.assertEqual(params, QuicTransportParameters()) + + def test_params_unknown_legacy(self): + # fb.mvfst.net sends a proprietary parameter 65280 + data = binascii.unhexlify( + "006400050004800104000006000480010400000700048001040000040004801" + "0000000080008c0000000ffffffff00090008c0000000ffffffff0001000480" + "00ea60000a00010300030002500000020010616161616262626263636363646" + "46464ff00000100" + ) + + # parse + buf = Buffer(data=data) + params = pull_quic_transport_parameters( + buf, protocol_version=QuicProtocolVersion.DRAFT_25 + ) + self.assertEqual( + params, + QuicTransportParameters( + idle_timeout=60000, + stateless_reset_token=b"aaaabbbbccccdddd", + max_packet_size=4096, + initial_max_data=1048576, + 
initial_max_stream_data_bidi_local=66560, + initial_max_stream_data_bidi_remote=66560, + initial_max_stream_data_uni=66560, + initial_max_streams_bidi=4294967295, + initial_max_streams_uni=4294967295, + ack_delay_exponent=3, + ), + ) + + def test_preferred_address_ipv4_only(self): + data = binascii.unhexlify( + "8ba27b8611530000000000000000000000000000000000001262c4518d63013" + "f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48ecb4" + ) + + # parse + buf = Buffer(data=data) + preferred_address = pull_quic_preferred_address(buf) + self.assertEqual( + preferred_address, + QuicPreferredAddress( + ipv4_address=("139.162.123.134", 4435), + ipv6_address=None, + connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7", + stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4", + ), + ) + + # serialize + buf = Buffer(capacity=len(data)) + push_quic_preferred_address(buf, preferred_address) + self.assertEqual(buf.data, data) + + def test_preferred_address_ipv6_only(self): + data = binascii.unhexlify( + "0000000000002400890200000000f03c91fffe69a45411531262c4518d63013" + "f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48ecb4" + ) + + # parse + buf = Buffer(data=data) + preferred_address = pull_quic_preferred_address(buf) + self.assertEqual( + preferred_address, + QuicPreferredAddress( + ipv4_address=None, + ipv6_address=("2400:8902::f03c:91ff:fe69:a454", 4435), + connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7", + stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4", + ), + ) + + # serialize + buf = Buffer(capacity=len(data)) + push_quic_preferred_address(buf, preferred_address) + self.assertEqual(buf.data, data) + + +class FrameTest(TestCase): + def test_ack_frame(self): + data = b"\x00\x02\x00\x00" + + # parse + buf = Buffer(data=data) + rangeset, delay = packet.pull_ack_frame(buf) + self.assertEqual(list(rangeset), [range(0, 1)]) + self.assertEqual(delay, 2) + + # serialize + buf = Buffer(capacity=len(data)) + packet.push_ack_frame(buf, 
rangeset, delay) + self.assertEqual(buf.data, data) + + def test_ack_frame_with_one_range(self): + data = b"\x02\x02\x01\x00\x00\x00" + + # parse + buf = Buffer(data=data) + rangeset, delay = packet.pull_ack_frame(buf) + self.assertEqual(list(rangeset), [range(0, 1), range(2, 3)]) + self.assertEqual(delay, 2) + + # serialize + buf = Buffer(capacity=len(data)) + packet.push_ack_frame(buf, rangeset, delay) + self.assertEqual(buf.data, data) + + def test_ack_frame_with_one_range_2(self): + data = b"\x05\x02\x01\x00\x00\x03" + + # parse + buf = Buffer(data=data) + rangeset, delay = packet.pull_ack_frame(buf) + self.assertEqual(list(rangeset), [range(0, 4), range(5, 6)]) + self.assertEqual(delay, 2) + + # serialize + buf = Buffer(capacity=len(data)) + packet.push_ack_frame(buf, rangeset, delay) + self.assertEqual(buf.data, data) + + def test_ack_frame_with_one_range_3(self): + data = b"\x05\x02\x01\x00\x01\x02" + + # parse + buf = Buffer(data=data) + rangeset, delay = packet.pull_ack_frame(buf) + self.assertEqual(list(rangeset), [range(0, 3), range(5, 6)]) + self.assertEqual(delay, 2) + + # serialize + buf = Buffer(capacity=len(data)) + packet.push_ack_frame(buf, rangeset, delay) + self.assertEqual(buf.data, data) + + def test_ack_frame_with_two_ranges(self): + data = b"\x04\x02\x02\x00\x00\x00\x00\x00" + + # parse + buf = Buffer(data=data) + rangeset, delay = packet.pull_ack_frame(buf) + self.assertEqual(list(rangeset), [range(0, 1), range(2, 3), range(4, 5)]) + self.assertEqual(delay, 2) + + # serialize + buf = Buffer(capacity=len(data)) + packet.push_ack_frame(buf, rangeset, delay) + self.assertEqual(buf.data, data) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_packet_builder.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_packet_builder.py new file mode 100644 index 000000000000..d3e2311d9912 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_packet_builder.py @@ -0,0 +1,569 @@ 
+from unittest import TestCase + +from aioquic.quic.crypto import CryptoPair +from aioquic.quic.packet import ( + PACKET_TYPE_HANDSHAKE, + PACKET_TYPE_INITIAL, + PACKET_TYPE_ONE_RTT, + QuicFrameType, + QuicProtocolVersion, +) +from aioquic.quic.packet_builder import ( + QuicPacketBuilder, + QuicPacketBuilderStop, + QuicSentPacket, +) +from aioquic.tls import Epoch + + +def create_builder(is_client=False): + return QuicPacketBuilder( + host_cid=bytes(8), + is_client=is_client, + packet_number=0, + peer_cid=bytes(8), + peer_token=b"", + spin_bit=False, + version=QuicProtocolVersion.DRAFT_25, + ) + + +def create_crypto(): + crypto = CryptoPair() + crypto.setup_initial(bytes(8), is_client=True, version=QuicProtocolVersion.DRAFT_25) + return crypto + + +class QuicPacketBuilderTest(TestCase): + def test_long_header_empty(self): + builder = create_builder() + crypto = create_crypto() + + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertEqual(builder.remaining_flight_space, 1236) + self.assertTrue(builder.packet_is_empty) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 0) + self.assertEqual(packets, []) + + # check builder + self.assertEqual(builder.packet_number, 0) + + def test_long_header_padding(self): + builder = create_builder(is_client=True) + crypto = create_crypto() + + # INITIAL, fully padded + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertEqual(builder.remaining_flight_space, 1236) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(100)) + self.assertFalse(builder.packet_is_empty) + + # INITIAL, empty + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertTrue(builder.packet_is_empty) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 1) + self.assertEqual(len(datagrams[0]), 1280) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + 
is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_INITIAL, + sent_bytes=1280, + ) + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 1) + + def test_long_header_initial_client_2(self): + builder = create_builder(is_client=True) + crypto = create_crypto() + + # INITIAL, full length + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertEqual(builder.remaining_flight_space, 1236) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(builder.remaining_flight_space)) + self.assertFalse(builder.packet_is_empty) + + # INITIAL + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertEqual(builder.remaining_flight_space, 1236) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(100)) + self.assertFalse(builder.packet_is_empty) + + # INITIAL, empty + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertTrue(builder.packet_is_empty) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 2) + self.assertEqual(len(datagrams[0]), 1280) + self.assertEqual(len(datagrams[1]), 1280) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_INITIAL, + sent_bytes=1280, + ), + QuicSentPacket( + epoch=Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=1, + packet_type=PACKET_TYPE_INITIAL, + sent_bytes=1280, + ), + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 2) + + def test_long_header_initial_server(self): + builder = create_builder() + crypto = create_crypto() + + # INITIAL + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertEqual(builder.remaining_flight_space, 1236) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(100)) + self.assertFalse(builder.packet_is_empty) + + # INITIAL, empty + 
builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertTrue(builder.packet_is_empty) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 1) + self.assertEqual(len(datagrams[0]), 145) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_INITIAL, + sent_bytes=145, + ) + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 1) + + def test_long_header_then_short_header(self): + builder = create_builder() + crypto = create_crypto() + + # INITIAL, full length + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertEqual(builder.remaining_flight_space, 1236) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(builder.remaining_flight_space)) + self.assertFalse(builder.packet_is_empty) + + # INITIAL, empty + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertTrue(builder.packet_is_empty) + + # ONE_RTT, full length + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + self.assertEqual(builder.remaining_flight_space, 1253) + buf = builder.start_frame(QuicFrameType.STREAM_BASE) + buf.push_bytes(bytes(builder.remaining_flight_space)) + self.assertFalse(builder.packet_is_empty) + + # ONE_RTT, empty + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + self.assertTrue(builder.packet_is_empty) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 2) + self.assertEqual(len(datagrams[0]), 1280) + self.assertEqual(len(datagrams[1]), 1280) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_INITIAL, + sent_bytes=1280, + ), + QuicSentPacket( + epoch=Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=1, + 
packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + ), + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 2) + + def test_long_header_then_long_header(self): + builder = create_builder() + crypto = create_crypto() + + # INITIAL + builder.start_packet(PACKET_TYPE_INITIAL, crypto) + self.assertEqual(builder.remaining_flight_space, 1236) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(199)) + self.assertFalse(builder.packet_is_empty) + + # HANDSHAKE + builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto) + self.assertEqual(builder.remaining_flight_space, 993) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(299)) + self.assertFalse(builder.packet_is_empty) + + # ONE_RTT + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + self.assertEqual(builder.remaining_flight_space, 666) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(299)) + self.assertFalse(builder.packet_is_empty) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 1) + self.assertEqual(len(datagrams[0]), 914) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_INITIAL, + sent_bytes=244, + ), + QuicSentPacket( + epoch=Epoch.HANDSHAKE, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=1, + packet_type=PACKET_TYPE_HANDSHAKE, + sent_bytes=343, + ), + QuicSentPacket( + epoch=Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=2, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=327, + ), + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 3) + + def test_short_header_empty(self): + builder = create_builder() + crypto = create_crypto() + + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + self.assertEqual(builder.remaining_flight_space, 1253) + 
self.assertTrue(builder.packet_is_empty) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(datagrams, []) + self.assertEqual(packets, []) + + # check builder + self.assertEqual(builder.packet_number, 0) + + def test_short_header_padding(self): + builder = create_builder() + crypto = create_crypto() + + # ONE_RTT, full length + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + self.assertEqual(builder.remaining_flight_space, 1253) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(builder.remaining_flight_space)) + self.assertFalse(builder.packet_is_empty) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 1) + self.assertEqual(len(datagrams[0]), 1280) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + ) + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 1) + + def test_short_header_max_flight_bytes(self): + """ + max_flight_bytes limits sent data. 
+ """ + builder = create_builder() + builder.max_flight_bytes = 1000 + + crypto = create_crypto() + + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + self.assertEqual(builder.remaining_flight_space, 973) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(builder.remaining_flight_space)) + self.assertFalse(builder.packet_is_empty) + + with self.assertRaises(QuicPacketBuilderStop): + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_frame(QuicFrameType.CRYPTO) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 1) + self.assertEqual(len(datagrams[0]), 1000) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1000, + ), + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 1) + + def test_short_header_max_flight_bytes_zero(self): + """ + max_flight_bytes = 0 only allows ACKs and CONNECTION_CLOSE. + + Check CRYPTO is not allowed. + """ + builder = create_builder() + builder.max_flight_bytes = 0 + + crypto = create_crypto() + + with self.assertRaises(QuicPacketBuilderStop): + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_frame(QuicFrameType.CRYPTO) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 0) + + # check builder + self.assertEqual(builder.packet_number, 0) + + def test_short_header_max_flight_bytes_zero_ack(self): + """ + max_flight_bytes = 0 only allows ACKs and CONNECTION_CLOSE. + + Check ACK is allowed. 
+ """ + builder = create_builder() + builder.max_flight_bytes = 0 + + crypto = create_crypto() + + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + buf = builder.start_frame(QuicFrameType.ACK) + buf.push_bytes(bytes(64)) + + with self.assertRaises(QuicPacketBuilderStop): + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_frame(QuicFrameType.CRYPTO) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 1) + self.assertEqual(len(datagrams[0]), 92) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.ONE_RTT, + in_flight=False, + is_ack_eliciting=False, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=92, + ), + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 1) + + def test_short_header_max_total_bytes_1(self): + """ + max_total_bytes doesn't allow any packets. + """ + builder = create_builder() + builder.max_total_bytes = 11 + + crypto = create_crypto() + + with self.assertRaises(QuicPacketBuilderStop): + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(datagrams, []) + self.assertEqual(packets, []) + + # check builder + self.assertEqual(builder.packet_number, 0) + + def test_short_header_max_total_bytes_2(self): + """ + max_total_bytes allows a short packet. 
+ """ + builder = create_builder() + builder.max_total_bytes = 800 + + crypto = create_crypto() + + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + self.assertEqual(builder.remaining_flight_space, 773) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(builder.remaining_flight_space)) + self.assertFalse(builder.packet_is_empty) + + with self.assertRaises(QuicPacketBuilderStop): + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 1) + self.assertEqual(len(datagrams[0]), 800) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=800, + ) + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 1) + + def test_short_header_max_total_bytes_3(self): + builder = create_builder() + builder.max_total_bytes = 2000 + + crypto = create_crypto() + + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + self.assertEqual(builder.remaining_flight_space, 1253) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(builder.remaining_flight_space)) + self.assertFalse(builder.packet_is_empty) + + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + self.assertEqual(builder.remaining_flight_space, 693) + buf = builder.start_frame(QuicFrameType.CRYPTO) + buf.push_bytes(bytes(builder.remaining_flight_space)) + self.assertFalse(builder.packet_is_empty) + + with self.assertRaises(QuicPacketBuilderStop): + builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + + # check datagrams + datagrams, packets = builder.flush() + self.assertEqual(len(datagrams), 2) + self.assertEqual(len(datagrams[0]), 1280) + self.assertEqual(len(datagrams[1]), 720) + self.assertEqual( + packets, + [ + QuicSentPacket( + epoch=Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + 
packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + ), + QuicSentPacket( + epoch=Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=1, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=720, + ), + ], + ) + + # check builder + self.assertEqual(builder.packet_number, 2) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_rangeset.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_rangeset.py new file mode 100644 index 000000000000..5b408c8ac6c5 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_rangeset.py @@ -0,0 +1,237 @@ +from unittest import TestCase + +from aioquic.quic.rangeset import RangeSet + + +class RangeSetTest(TestCase): + def test_add_single_duplicate(self): + rangeset = RangeSet() + + rangeset.add(0) + self.assertEqual(list(rangeset), [range(0, 1)]) + + rangeset.add(0) + self.assertEqual(list(rangeset), [range(0, 1)]) + + def test_add_single_ordered(self): + rangeset = RangeSet() + + rangeset.add(0) + self.assertEqual(list(rangeset), [range(0, 1)]) + + rangeset.add(1) + self.assertEqual(list(rangeset), [range(0, 2)]) + + rangeset.add(2) + self.assertEqual(list(rangeset), [range(0, 3)]) + + def test_add_single_merge(self): + rangeset = RangeSet() + + rangeset.add(0) + self.assertEqual(list(rangeset), [range(0, 1)]) + + rangeset.add(2) + self.assertEqual(list(rangeset), [range(0, 1), range(2, 3)]) + + rangeset.add(1) + self.assertEqual(list(rangeset), [range(0, 3)]) + + def test_add_single_reverse(self): + rangeset = RangeSet() + + rangeset.add(2) + self.assertEqual(list(rangeset), [range(2, 3)]) + + rangeset.add(1) + self.assertEqual(list(rangeset), [range(1, 3)]) + + rangeset.add(0) + self.assertEqual(list(rangeset), [range(0, 3)]) + + def test_add_range_ordered(self): + rangeset = RangeSet() + + rangeset.add(0, 2) + self.assertEqual(list(rangeset), [range(0, 2)]) + + rangeset.add(2, 4) + 
self.assertEqual(list(rangeset), [range(0, 4)]) + + rangeset.add(4, 6) + self.assertEqual(list(rangeset), [range(0, 6)]) + + def test_add_range_merge(self): + rangeset = RangeSet() + + rangeset.add(0, 2) + self.assertEqual(list(rangeset), [range(0, 2)]) + + rangeset.add(3, 5) + self.assertEqual(list(rangeset), [range(0, 2), range(3, 5)]) + + rangeset.add(2, 3) + self.assertEqual(list(rangeset), [range(0, 5)]) + + def test_add_range_overlap(self): + rangeset = RangeSet() + + rangeset.add(0, 2) + self.assertEqual(list(rangeset), [range(0, 2)]) + + rangeset.add(3, 5) + self.assertEqual(list(rangeset), [range(0, 2), range(3, 5)]) + + rangeset.add(1, 5) + self.assertEqual(list(rangeset), [range(0, 5)]) + + def test_add_range_overlap_2(self): + rangeset = RangeSet() + + rangeset.add(2, 4) + rangeset.add(6, 8) + rangeset.add(10, 12) + rangeset.add(16, 18) + self.assertEqual( + list(rangeset), [range(2, 4), range(6, 8), range(10, 12), range(16, 18)] + ) + + rangeset.add(1, 15) + self.assertEqual(list(rangeset), [range(1, 15), range(16, 18)]) + + def test_add_range_reverse(self): + rangeset = RangeSet() + + rangeset.add(6, 8) + self.assertEqual(list(rangeset), [range(6, 8)]) + + rangeset.add(3, 5) + self.assertEqual(list(rangeset), [range(3, 5), range(6, 8)]) + + rangeset.add(0, 2) + self.assertEqual(list(rangeset), [range(0, 2), range(3, 5), range(6, 8)]) + + def test_add_range_unordered_contiguous(self): + rangeset = RangeSet() + + rangeset.add(0, 2) + self.assertEqual(list(rangeset), [range(0, 2)]) + + rangeset.add(4, 6) + self.assertEqual(list(rangeset), [range(0, 2), range(4, 6)]) + + rangeset.add(2, 4) + self.assertEqual(list(rangeset), [range(0, 6)]) + + def test_add_range_unordered_sparse(self): + rangeset = RangeSet() + + rangeset.add(0, 2) + self.assertEqual(list(rangeset), [range(0, 2)]) + + rangeset.add(6, 8) + self.assertEqual(list(rangeset), [range(0, 2), range(6, 8)]) + + rangeset.add(3, 5) + self.assertEqual(list(rangeset), [range(0, 2), range(3, 5), 
range(6, 8)]) + + def test_subtract(self): + rangeset = RangeSet() + rangeset.add(0, 10) + rangeset.add(20, 30) + + rangeset.subtract(0, 3) + self.assertEqual(list(rangeset), [range(3, 10), range(20, 30)]) + + def test_subtract_no_change(self): + rangeset = RangeSet() + rangeset.add(5, 10) + rangeset.add(15, 20) + rangeset.add(25, 30) + + rangeset.subtract(0, 5) + self.assertEqual(list(rangeset), [range(5, 10), range(15, 20), range(25, 30)]) + + rangeset.subtract(10, 15) + self.assertEqual(list(rangeset), [range(5, 10), range(15, 20), range(25, 30)]) + + def test_subtract_overlap(self): + rangeset = RangeSet() + rangeset.add(1, 4) + rangeset.add(6, 8) + rangeset.add(10, 20) + rangeset.add(30, 40) + self.assertEqual( + list(rangeset), [range(1, 4), range(6, 8), range(10, 20), range(30, 40)] + ) + + rangeset.subtract(0, 2) + self.assertEqual( + list(rangeset), [range(2, 4), range(6, 8), range(10, 20), range(30, 40)] + ) + + rangeset.subtract(3, 11) + self.assertEqual(list(rangeset), [range(2, 3), range(11, 20), range(30, 40)]) + + def test_subtract_split(self): + rangeset = RangeSet() + rangeset.add(0, 10) + rangeset.subtract(2, 5) + self.assertEqual(list(rangeset), [range(0, 2), range(5, 10)]) + + def test_bool(self): + with self.assertRaises(NotImplementedError): + bool(RangeSet()) + + def test_contains(self): + rangeset = RangeSet() + self.assertFalse(0 in rangeset) + + rangeset = RangeSet([range(0, 1)]) + self.assertTrue(0 in rangeset) + self.assertFalse(1 in rangeset) + + rangeset = RangeSet([range(0, 1), range(3, 6)]) + self.assertTrue(0 in rangeset) + self.assertFalse(1 in rangeset) + self.assertFalse(2 in rangeset) + self.assertTrue(3 in rangeset) + self.assertTrue(4 in rangeset) + self.assertTrue(5 in rangeset) + self.assertFalse(6 in rangeset) + + def test_eq(self): + r0 = RangeSet([range(0, 1)]) + r1 = RangeSet([range(1, 2), range(3, 4)]) + r2 = RangeSet([range(3, 4), range(1, 2)]) + + self.assertTrue(r0 == r0) + self.assertFalse(r0 == r1) + 
self.assertFalse(r0 == 0) + + self.assertTrue(r1 == r1) + self.assertFalse(r1 == r0) + self.assertTrue(r1 == r2) + self.assertFalse(r1 == 0) + + self.assertTrue(r2 == r2) + self.assertTrue(r2 == r1) + self.assertFalse(r2 == r0) + self.assertFalse(r2 == 0) + + def test_len(self): + rangeset = RangeSet() + self.assertEqual(len(rangeset), 0) + + rangeset = RangeSet([range(0, 1)]) + self.assertEqual(len(rangeset), 1) + + def test_pop(self): + rangeset = RangeSet([range(1, 2), range(3, 4)]) + r = rangeset.shift() + self.assertEqual(r, range(1, 2)) + self.assertEqual(list(rangeset), [range(3, 4)]) + + def test_repr(self): + rangeset = RangeSet([range(1, 2), range(3, 4)]) + self.assertEqual(repr(rangeset), "RangeSet([range(1, 2), range(3, 4)])") diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_recovery.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_recovery.py new file mode 100644 index 000000000000..c313f8fc8a4c --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_recovery.py @@ -0,0 +1,227 @@ +import math +from unittest import TestCase + +from aioquic import tls +from aioquic.quic.packet import PACKET_TYPE_INITIAL, PACKET_TYPE_ONE_RTT +from aioquic.quic.packet_builder import QuicSentPacket +from aioquic.quic.rangeset import RangeSet +from aioquic.quic.recovery import ( + QuicPacketPacer, + QuicPacketRecovery, + QuicPacketSpace, + QuicRttMonitor, +) + + +def send_probe(): + pass + + +class QuicPacketPacerTest(TestCase): + def setUp(self): + self.pacer = QuicPacketPacer() + + def test_no_measurement(self): + self.assertIsNone(self.pacer.next_send_time(now=0.0)) + self.pacer.update_after_send(now=0.0) + + self.assertIsNone(self.pacer.next_send_time(now=0.0)) + self.pacer.update_after_send(now=0.0) + + def test_with_measurement(self): + self.assertIsNone(self.pacer.next_send_time(now=0.0)) + self.pacer.update_after_send(now=0.0) + + self.pacer.update_rate(congestion_window=1280000, 
smoothed_rtt=0.05) + self.assertEqual(self.pacer.bucket_max, 0.0008) + self.assertEqual(self.pacer.bucket_time, 0.0) + self.assertEqual(self.pacer.packet_time, 0.00005) + + # 16 packets + for i in range(16): + self.assertIsNone(self.pacer.next_send_time(now=1.0)) + self.pacer.update_after_send(now=1.0) + self.assertAlmostEqual(self.pacer.next_send_time(now=1.0), 1.00005) + + # 2 packets + for i in range(2): + self.assertIsNone(self.pacer.next_send_time(now=1.00005)) + self.pacer.update_after_send(now=1.00005) + self.assertAlmostEqual(self.pacer.next_send_time(now=1.00005), 1.0001) + + # 1 packet + self.assertIsNone(self.pacer.next_send_time(now=1.0001)) + self.pacer.update_after_send(now=1.0001) + self.assertAlmostEqual(self.pacer.next_send_time(now=1.0001), 1.00015) + + # 2 packets + for i in range(2): + self.assertIsNone(self.pacer.next_send_time(now=1.00015)) + self.pacer.update_after_send(now=1.00015) + self.assertAlmostEqual(self.pacer.next_send_time(now=1.00015), 1.0002) + + +class QuicPacketRecoveryTest(TestCase): + def setUp(self): + self.INITIAL_SPACE = QuicPacketSpace() + self.HANDSHAKE_SPACE = QuicPacketSpace() + self.ONE_RTT_SPACE = QuicPacketSpace() + + self.recovery = QuicPacketRecovery( + is_client_without_1rtt=False, send_probe=send_probe + ) + self.recovery.spaces = [ + self.INITIAL_SPACE, + self.HANDSHAKE_SPACE, + self.ONE_RTT_SPACE, + ] + + def test_discard_space(self): + self.recovery.discard_space(self.INITIAL_SPACE) + + def test_on_ack_received_ack_eliciting(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=0.0, + ) + space = self.ONE_RTT_SPACE + + #  packet sent + self.recovery.on_packet_sent(packet, space) + self.assertEqual(self.recovery.bytes_in_flight, 1280) + self.assertEqual(space.ack_eliciting_in_flight, 1) + self.assertEqual(len(space.sent_packets), 1) + + # packet ack'd + 
self.recovery.on_ack_received( + space, ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0 + ) + self.assertEqual(self.recovery.bytes_in_flight, 0) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 0) + + # check RTT + self.assertTrue(self.recovery._rtt_initialized) + self.assertEqual(self.recovery._rtt_latest, 10.0) + self.assertEqual(self.recovery._rtt_min, 10.0) + self.assertEqual(self.recovery._rtt_smoothed, 10.0) + + def test_on_ack_received_non_ack_eliciting(self): + packet = QuicSentPacket( + epoch=tls.Epoch.ONE_RTT, + in_flight=True, + is_ack_eliciting=False, + is_crypto_packet=False, + packet_number=0, + packet_type=PACKET_TYPE_ONE_RTT, + sent_bytes=1280, + sent_time=123.45, + ) + space = self.ONE_RTT_SPACE + + #  packet sent + self.recovery.on_packet_sent(packet, space) + self.assertEqual(self.recovery.bytes_in_flight, 1280) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 1) + + # packet ack'd + self.recovery.on_ack_received( + space, ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0 + ) + self.assertEqual(self.recovery.bytes_in_flight, 0) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 0) + + # check RTT + self.assertFalse(self.recovery._rtt_initialized) + self.assertEqual(self.recovery._rtt_latest, 0.0) + self.assertEqual(self.recovery._rtt_min, math.inf) + self.assertEqual(self.recovery._rtt_smoothed, 0.0) + + def test_on_packet_lost_crypto(self): + packet = QuicSentPacket( + epoch=tls.Epoch.INITIAL, + in_flight=True, + is_ack_eliciting=True, + is_crypto_packet=True, + packet_number=0, + packet_type=PACKET_TYPE_INITIAL, + sent_bytes=1280, + sent_time=0.0, + ) + space = self.INITIAL_SPACE + + self.recovery.on_packet_sent(packet, space) + self.assertEqual(self.recovery.bytes_in_flight, 1280) + self.assertEqual(space.ack_eliciting_in_flight, 1) + self.assertEqual(len(space.sent_packets), 1) 
+ + self.recovery._detect_loss(space, now=1.0) + self.assertEqual(self.recovery.bytes_in_flight, 0) + self.assertEqual(space.ack_eliciting_in_flight, 0) + self.assertEqual(len(space.sent_packets), 0) + + +class QuicRttMonitorTest(TestCase): + def test_monitor(self): + monitor = QuicRttMonitor() + + self.assertFalse(monitor.is_rtt_increasing(rtt=10, now=1000)) + self.assertEqual(monitor._samples, [10, 0.0, 0.0, 0.0, 0.0]) + self.assertFalse(monitor._ready) + + # not taken into account + self.assertFalse(monitor.is_rtt_increasing(rtt=11, now=1000)) + self.assertEqual(monitor._samples, [10, 0.0, 0.0, 0.0, 0.0]) + self.assertFalse(monitor._ready) + + self.assertFalse(monitor.is_rtt_increasing(rtt=11, now=1001)) + self.assertEqual(monitor._samples, [10, 11, 0.0, 0.0, 0.0]) + self.assertFalse(monitor._ready) + + self.assertFalse(monitor.is_rtt_increasing(rtt=12, now=1002)) + self.assertEqual(monitor._samples, [10, 11, 12, 0.0, 0.0]) + self.assertFalse(monitor._ready) + + self.assertFalse(monitor.is_rtt_increasing(rtt=13, now=1003)) + self.assertEqual(monitor._samples, [10, 11, 12, 13, 0.0]) + self.assertFalse(monitor._ready) + + # we now have enough samples + self.assertFalse(monitor.is_rtt_increasing(rtt=14, now=1004)) + self.assertEqual(monitor._samples, [10, 11, 12, 13, 14]) + self.assertTrue(monitor._ready) + + self.assertFalse(monitor.is_rtt_increasing(rtt=20, now=1005)) + self.assertEqual(monitor._increases, 0) + + self.assertFalse(monitor.is_rtt_increasing(rtt=30, now=1006)) + self.assertEqual(monitor._increases, 0) + + self.assertFalse(monitor.is_rtt_increasing(rtt=40, now=1007)) + self.assertEqual(monitor._increases, 0) + + self.assertFalse(monitor.is_rtt_increasing(rtt=50, now=1008)) + self.assertEqual(monitor._increases, 0) + + self.assertFalse(monitor.is_rtt_increasing(rtt=60, now=1009)) + self.assertEqual(monitor._increases, 1) + + self.assertFalse(monitor.is_rtt_increasing(rtt=70, now=1010)) + self.assertEqual(monitor._increases, 2) + + 
self.assertFalse(monitor.is_rtt_increasing(rtt=80, now=1011)) + self.assertEqual(monitor._increases, 3) + + self.assertFalse(monitor.is_rtt_increasing(rtt=90, now=1012)) + self.assertEqual(monitor._increases, 4) + + self.assertTrue(monitor.is_rtt_increasing(rtt=100, now=1013)) + self.assertEqual(monitor._increases, 5) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_retry.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_retry.py new file mode 100644 index 000000000000..376f023bf93d --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_retry.py @@ -0,0 +1,30 @@ +from unittest import TestCase + +from aioquic.quic.retry import QuicRetryTokenHandler + + +class QuicRetryTokenHandlerTest(TestCase): + def test_retry_token(self): + addr = ("127.0.0.1", 1234) + cid = b"\x08\x07\x06\05\x04\x03\x02\x01" + + handler = QuicRetryTokenHandler() + + # create token + token = handler.create_token(addr, cid) + self.assertIsNotNone(token) + + # validate token - ok + self.assertEqual(handler.validate_token(addr, token), cid) + + # validate token - empty + with self.assertRaises(ValueError) as cm: + handler.validate_token(addr, b"") + self.assertEqual( + str(cm.exception), "Ciphertext length must be equal to key size." 
+ ) + + # validate token - wrong address + with self.assertRaises(ValueError) as cm: + handler.validate_token(("1.2.3.4", 12345), token) + self.assertEqual(str(cm.exception), "Remote address does not match.") diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_stream.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_stream.py new file mode 100644 index 000000000000..3537ae2ed12b --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_stream.py @@ -0,0 +1,446 @@ +from unittest import TestCase + +from aioquic.quic.events import StreamDataReceived +from aioquic.quic.packet import QuicStreamFrame +from aioquic.quic.packet_builder import QuicDeliveryState +from aioquic.quic.stream import QuicStream + + +class QuicStreamTest(TestCase): + def test_recv_empty(self): + stream = QuicStream(stream_id=0) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 0) + + # empty + self.assertEqual(stream.add_frame(QuicStreamFrame(offset=0, data=b"")), None) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 0) + + def test_recv_ordered(self): + stream = QuicStream(stream_id=0) + + # add data at start + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")), + StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), + ) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 8) + + # add more data + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345")), + StreamDataReceived(data=b"89012345", end_stream=False, stream_id=0), + ) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + 
self.assertEqual(stream._recv_buffer_start, 16) + + # add data and fin + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=16, data=b"67890123", fin=True)), + StreamDataReceived(data=b"67890123", end_stream=True, stream_id=0), + ) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 24) + + def test_recv_unordered(self): + stream = QuicStream(stream_id=0) + + # add data at offset 8 + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345")), None + ) + self.assertEqual( + bytes(stream._recv_buffer), b"\x00\x00\x00\x00\x00\x00\x00\x0089012345" + ) + self.assertEqual(list(stream._recv_ranges), [range(8, 16)]) + self.assertEqual(stream._recv_buffer_start, 0) + + # add data at offset 0 + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")), + StreamDataReceived(data=b"0123456789012345", end_stream=False, stream_id=0), + ) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 16) + + def test_recv_offset_only(self): + stream = QuicStream(stream_id=0) + + # add data at offset 0 + self.assertEqual(stream.add_frame(QuicStreamFrame(offset=0, data=b"")), None) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 0) + + # add data at offset 8 + self.assertEqual(stream.add_frame(QuicStreamFrame(offset=8, data=b"")), None) + self.assertEqual( + bytes(stream._recv_buffer), b"\x00\x00\x00\x00\x00\x00\x00\x00" + ) + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 0) + + def test_recv_already_fully_consumed(self): + stream = QuicStream(stream_id=0) + + # add data at offset 0 + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")), + StreamDataReceived(data=b"01234567", 
end_stream=False, stream_id=0), + ) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 8) + + # add data again at offset 0 + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")), None + ) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 8) + + def test_recv_already_partially_consumed(self): + stream = QuicStream(stream_id=0) + + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")), + StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), + ) + + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"0123456789012345")), + StreamDataReceived(data=b"89012345", end_stream=False, stream_id=0), + ) + self.assertEqual(bytes(stream._recv_buffer), b"") + self.assertEqual(list(stream._recv_ranges), []) + self.assertEqual(stream._recv_buffer_start, 16) + + def test_recv_fin(self): + stream = QuicStream(stream_id=0) + + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")), + StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), + ) + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)), + StreamDataReceived(data=b"89012345", end_stream=True, stream_id=0), + ) + + def test_recv_fin_out_of_order(self): + stream = QuicStream(stream_id=0) + + # add data at offset 8 with FIN + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)), + None, + ) + + # add data at offset 0 + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")), + StreamDataReceived(data=b"0123456789012345", end_stream=True, stream_id=0), + ) + + def test_recv_fin_then_data(self): + stream = QuicStream(stream_id=0) + stream.add_frame(QuicStreamFrame(offset=0, data=b"", fin=True)) + with 
self.assertRaises(Exception) as cm: + stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")) + self.assertEqual(str(cm.exception), "Data received beyond FIN") + + def test_recv_fin_twice(self): + stream = QuicStream(stream_id=0) + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")), + StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0), + ) + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)), + StreamDataReceived(data=b"89012345", end_stream=True, stream_id=0), + ) + + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)), + StreamDataReceived(data=b"", end_stream=True, stream_id=0), + ) + + def test_recv_fin_without_data(self): + stream = QuicStream(stream_id=0) + self.assertEqual( + stream.add_frame(QuicStreamFrame(offset=0, data=b"", fin=True)), + StreamDataReceived(data=b"", end_stream=True, stream_id=0), + ) + + def test_send_data(self): + stream = QuicStream() + self.assertEqual(stream.next_send_offset, 0) + + # nothing to send yet + frame = stream.get_frame(8) + self.assertIsNone(frame) + + # write data + stream.write(b"0123456789012345") + self.assertEqual(list(stream._send_pending), [range(0, 16)]) + self.assertEqual(stream.next_send_offset, 0) + + # send a chunk + frame = stream.get_frame(8) + self.assertEqual(frame.data, b"01234567") + self.assertFalse(frame.fin) + self.assertEqual(frame.offset, 0) + self.assertEqual(list(stream._send_pending), [range(8, 16)]) + self.assertEqual(stream.next_send_offset, 8) + + # send another chunk + frame = stream.get_frame(8) + self.assertEqual(frame.data, b"89012345") + self.assertFalse(frame.fin) + self.assertEqual(frame.offset, 8) + self.assertEqual(list(stream._send_pending), []) + self.assertEqual(stream.next_send_offset, 16) + + # nothing more to send + frame = stream.get_frame(8) + self.assertIsNone(frame) + self.assertEqual(list(stream._send_pending), []) + 
self.assertEqual(stream.next_send_offset, 16) + + # first chunk gets acknowledged + stream.on_data_delivery(QuicDeliveryState.ACKED, 0, 8) + + # second chunk gets acknowledged + stream.on_data_delivery(QuicDeliveryState.ACKED, 8, 16) + + def test_send_data_and_fin(self): + stream = QuicStream() + + # nothing to send yet + frame = stream.get_frame(8) + self.assertIsNone(frame) + + # write data and EOF + stream.write(b"0123456789012345", end_stream=True) + self.assertEqual(list(stream._send_pending), [range(0, 16)]) + self.assertEqual(stream.next_send_offset, 0) + + # send a chunk + frame = stream.get_frame(8) + self.assertEqual(frame.data, b"01234567") + self.assertFalse(frame.fin) + self.assertEqual(frame.offset, 0) + self.assertEqual(stream.next_send_offset, 8) + + # send another chunk + frame = stream.get_frame(8) + self.assertEqual(frame.data, b"89012345") + self.assertTrue(frame.fin) + self.assertEqual(frame.offset, 8) + self.assertEqual(stream.next_send_offset, 16) + + # nothing more to send + frame = stream.get_frame(8) + self.assertIsNone(frame) + self.assertEqual(stream.next_send_offset, 16) + + def test_send_data_lost(self): + stream = QuicStream() + + # nothing to send yet + frame = stream.get_frame(8) + self.assertIsNone(frame) + + # write data and EOF + stream.write(b"0123456789012345", end_stream=True) + self.assertEqual(list(stream._send_pending), [range(0, 16)]) + self.assertEqual(stream.next_send_offset, 0) + + # send a chunk + self.assertEqual( + stream.get_frame(8), QuicStreamFrame(data=b"01234567", fin=False, offset=0) + ) + self.assertEqual(list(stream._send_pending), [range(8, 16)]) + self.assertEqual(stream.next_send_offset, 8) + + # send another chunk + self.assertEqual( + stream.get_frame(8), QuicStreamFrame(data=b"89012345", fin=True, offset=8) + ) + self.assertEqual(list(stream._send_pending), []) + self.assertEqual(stream.next_send_offset, 16) + + # nothing more to send + self.assertIsNone(stream.get_frame(8)) + 
self.assertEqual(list(stream._send_pending), []) + self.assertEqual(stream.next_send_offset, 16) + + # a chunk gets lost + stream.on_data_delivery(QuicDeliveryState.LOST, 0, 8) + self.assertEqual(list(stream._send_pending), [range(0, 8)]) + self.assertEqual(stream.next_send_offset, 0) + + # send chunk again + self.assertEqual( + stream.get_frame(8), QuicStreamFrame(data=b"01234567", fin=False, offset=0) + ) + self.assertEqual(list(stream._send_pending), []) + self.assertEqual(stream.next_send_offset, 16) + + def test_send_data_lost_fin(self): + stream = QuicStream() + + # nothing to send yet + frame = stream.get_frame(8) + self.assertIsNone(frame) + + # write data and EOF + stream.write(b"0123456789012345", end_stream=True) + self.assertEqual(list(stream._send_pending), [range(0, 16)]) + self.assertEqual(stream.next_send_offset, 0) + + # send a chunk + self.assertEqual( + stream.get_frame(8), QuicStreamFrame(data=b"01234567", fin=False, offset=0) + ) + self.assertEqual(list(stream._send_pending), [range(8, 16)]) + self.assertEqual(stream.next_send_offset, 8) + + # send another chunk + self.assertEqual( + stream.get_frame(8), QuicStreamFrame(data=b"89012345", fin=True, offset=8) + ) + self.assertEqual(list(stream._send_pending), []) + self.assertEqual(stream.next_send_offset, 16) + + # nothing more to send + self.assertIsNone(stream.get_frame(8)) + self.assertEqual(list(stream._send_pending), []) + self.assertEqual(stream.next_send_offset, 16) + + # a chunk gets lost + stream.on_data_delivery(QuicDeliveryState.LOST, 8, 16) + self.assertEqual(list(stream._send_pending), [range(8, 16)]) + self.assertEqual(stream.next_send_offset, 8) + + # send chunk again + self.assertEqual( + stream.get_frame(8), QuicStreamFrame(data=b"89012345", fin=True, offset=8) + ) + self.assertEqual(list(stream._send_pending), []) + self.assertEqual(stream.next_send_offset, 16) + + def test_send_blocked(self): + stream = QuicStream() + max_offset = 12 + + # nothing to send yet + frame = 
stream.get_frame(8, max_offset) + self.assertIsNone(frame) + self.assertEqual(list(stream._send_pending), []) + self.assertEqual(stream.next_send_offset, 0) + + # write data, send a chunk + stream.write(b"0123456789012345") + frame = stream.get_frame(8) + self.assertEqual(frame.data, b"01234567") + self.assertFalse(frame.fin) + self.assertEqual(frame.offset, 0) + self.assertEqual(list(stream._send_pending), [range(8, 16)]) + self.assertEqual(stream.next_send_offset, 8) + + # send is limited by peer + frame = stream.get_frame(8, max_offset) + self.assertEqual(frame.data, b"8901") + self.assertFalse(frame.fin) + self.assertEqual(frame.offset, 8) + self.assertEqual(list(stream._send_pending), [range(12, 16)]) + self.assertEqual(stream.next_send_offset, 12) + + # unable to send, blocked + frame = stream.get_frame(8, max_offset) + self.assertIsNone(frame) + self.assertEqual(list(stream._send_pending), [range(12, 16)]) + self.assertEqual(stream.next_send_offset, 12) + + # write more data, still blocked + stream.write(b"abcdefgh") + frame = stream.get_frame(8, max_offset) + self.assertIsNone(frame) + self.assertEqual(list(stream._send_pending), [range(12, 24)]) + self.assertEqual(stream.next_send_offset, 12) + + # peer raises limit, send some data + max_offset += 8 + frame = stream.get_frame(8, max_offset) + self.assertEqual(frame.data, b"2345abcd") + self.assertFalse(frame.fin) + self.assertEqual(frame.offset, 12) + self.assertEqual(list(stream._send_pending), [range(20, 24)]) + self.assertEqual(stream.next_send_offset, 20) + + # peer raises limit again, send remaining data + max_offset += 8 + frame = stream.get_frame(8, max_offset) + self.assertEqual(frame.data, b"efgh") + self.assertFalse(frame.fin) + self.assertEqual(frame.offset, 20) + self.assertEqual(list(stream._send_pending), []) + self.assertEqual(stream.next_send_offset, 24) + + # nothing more to send + frame = stream.get_frame(8, max_offset) + self.assertIsNone(frame) + + def test_send_fin_only(self): + stream 
= QuicStream() + + # nothing to send yet + self.assertTrue(stream.send_buffer_is_empty) + frame = stream.get_frame(8) + self.assertIsNone(frame) + + # write EOF + stream.write(b"", end_stream=True) + self.assertFalse(stream.send_buffer_is_empty) + frame = stream.get_frame(8) + self.assertEqual(frame.data, b"") + self.assertTrue(frame.fin) + self.assertEqual(frame.offset, 0) + + # nothing more to send + self.assertFalse(stream.send_buffer_is_empty) # FIXME? + frame = stream.get_frame(8) + self.assertIsNone(frame) + self.assertTrue(stream.send_buffer_is_empty) + + def test_send_fin_only_despite_blocked(self): + stream = QuicStream() + + # nothing to send yet + self.assertTrue(stream.send_buffer_is_empty) + frame = stream.get_frame(8) + self.assertIsNone(frame) + + # write EOF + stream.write(b"", end_stream=True) + self.assertFalse(stream.send_buffer_is_empty) + frame = stream.get_frame(8) + self.assertEqual(frame.data, b"") + self.assertTrue(frame.fin) + self.assertEqual(frame.offset, 0) + + # nothing more to send + self.assertFalse(stream.send_buffer_is_empty) # FIXME? 
+ frame = stream.get_frame(8) + self.assertIsNone(frame) + self.assertTrue(stream.send_buffer_is_empty) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/test_tls.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_tls.py new file mode 100644 index 000000000000..72dd56deb8ba --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/test_tls.py @@ -0,0 +1,1400 @@ +import binascii +import datetime +import ssl +from unittest import TestCase +from unittest.mock import patch + +from cryptography.exceptions import UnsupportedAlgorithm +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec + +from aioquic import tls +from aioquic.buffer import Buffer, BufferReadError +from aioquic.quic.configuration import QuicConfiguration +from aioquic.tls import ( + Certificate, + CertificateVerify, + ClientHello, + Context, + EncryptedExtensions, + Finished, + NewSessionTicket, + ServerHello, + State, + load_pem_x509_certificates, + pull_block, + pull_certificate, + pull_certificate_verify, + pull_client_hello, + pull_encrypted_extensions, + pull_finished, + pull_new_session_ticket, + pull_server_hello, + push_certificate, + push_certificate_verify, + push_client_hello, + push_encrypted_extensions, + push_finished, + push_new_session_ticket, + push_server_hello, + verify_certificate, +) + +from .utils import ( + SERVER_CACERTFILE, + SERVER_CERTFILE, + SERVER_KEYFILE, + generate_ec_certificate, + load, +) + +CERTIFICATE_DATA = load("tls_certificate.bin")[11:-2] +CERTIFICATE_VERIFY_SIGNATURE = load("tls_certificate_verify.bin")[-384:] + +CLIENT_QUIC_TRANSPORT_PARAMETERS = binascii.unhexlify( + b"ff0000110031000500048010000000060004801000000007000480100000000" + b"4000481000000000100024258000800024064000a00010a" +) + +SERVER_QUIC_TRANSPORT_PARAMETERS = binascii.unhexlify( + b"ff00001104ff000011004500050004801000000006000480100000000700048" + 
b"010000000040004810000000001000242580002001000000000000000000000" + b"000000000000000800024064000a00010a" +) + +SERVER_QUIC_TRANSPORT_PARAMETERS_2 = binascii.unhexlify( + b"0057000600048000ffff000500048000ffff00020010c5ac410fbdd4fe6e2c1" + b"42279f231e8e0000a000103000400048005fffa000b000119000100026710ff" + b"42000c5c067f27e39321c63e28e7c90003000247e40008000106" +) + +SERVER_QUIC_TRANSPORT_PARAMETERS_3 = binascii.unhexlify( + b"0054000200100dcb50a442513295b4679baf04cb5effff8a0009c8afe72a6397" + b"255407000600048000ffff0008000106000400048005fffa000500048000ffff" + b"0003000247e4000a000103000100026710000b000119" +) + + +class BufferTest(TestCase): + def test_pull_block_truncated(self): + buf = Buffer(capacity=0) + with self.assertRaises(BufferReadError): + with pull_block(buf, 1): + pass + + +def create_buffers(): + return { + tls.Epoch.INITIAL: Buffer(capacity=4096), + tls.Epoch.HANDSHAKE: Buffer(capacity=4096), + tls.Epoch.ONE_RTT: Buffer(capacity=4096), + } + + +def merge_buffers(buffers): + return b"".join(x.data for x in buffers.values()) + + +def reset_buffers(buffers): + for k in buffers.keys(): + buffers[k].seek(0) + + +class ContextTest(TestCase): + def create_client( + self, alpn_protocols=None, cadata=None, cafile=SERVER_CACERTFILE, **kwargs + ): + client = Context( + alpn_protocols=alpn_protocols, + cadata=cadata, + cafile=cafile, + is_client=True, + **kwargs + ) + client.handshake_extensions = [ + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + CLIENT_QUIC_TRANSPORT_PARAMETERS, + ) + ] + self.assertEqual(client.state, State.CLIENT_HANDSHAKE_START) + return client + + def create_server(self, alpn_protocols=None): + configuration = QuicConfiguration(is_client=False) + configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE) + + server = Context( + alpn_protocols=alpn_protocols, is_client=False, max_early_data=0xFFFFFFFF + ) + server.certificate = configuration.certificate + server.certificate_private_key = configuration.private_key + 
server.handshake_extensions = [ + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + SERVER_QUIC_TRANSPORT_PARAMETERS, + ) + ] + self.assertEqual(server.state, State.SERVER_EXPECT_CLIENT_HELLO) + return server + + def test_client_unexpected_message(self): + client = self.create_client() + + client.state = State.CLIENT_EXPECT_SERVER_HELLO + with self.assertRaises(tls.AlertUnexpectedMessage): + client.handle_message(b"\x00\x00\x00\x00", create_buffers()) + + client.state = State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS + with self.assertRaises(tls.AlertUnexpectedMessage): + client.handle_message(b"\x00\x00\x00\x00", create_buffers()) + + client.state = State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE + with self.assertRaises(tls.AlertUnexpectedMessage): + client.handle_message(b"\x00\x00\x00\x00", create_buffers()) + + client.state = State.CLIENT_EXPECT_CERTIFICATE_VERIFY + with self.assertRaises(tls.AlertUnexpectedMessage): + client.handle_message(b"\x00\x00\x00\x00", create_buffers()) + + client.state = State.CLIENT_EXPECT_FINISHED + with self.assertRaises(tls.AlertUnexpectedMessage): + client.handle_message(b"\x00\x00\x00\x00", create_buffers()) + + client.state = State.CLIENT_POST_HANDSHAKE + with self.assertRaises(tls.AlertUnexpectedMessage): + client.handle_message(b"\x00\x00\x00\x00", create_buffers()) + + def test_client_bad_certificate_verify_data(self): + client = self.create_client() + server = self.create_server() + + # send client hello + client_buf = create_buffers() + client.handle_message(b"", client_buf) + self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + reset_buffers(client_buf) + + # handle client hello + # send server hello, encrypted extensions, certificate, certificate verify, finished + server_buf = create_buffers() + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) + client_input = merge_buffers(server_buf) + 
reset_buffers(server_buf) + + # mess with certificate verify + client_input = client_input[:-56] + bytes(4) + client_input[-52:] + + # handle server hello, encrypted extensions, certificate, certificate verify, finished + with self.assertRaises(tls.AlertDecryptError): + client.handle_message(client_input, client_buf) + + def test_client_bad_finished_verify_data(self): + client = self.create_client() + server = self.create_server() + + # send client hello + client_buf = create_buffers() + client.handle_message(b"", client_buf) + self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + reset_buffers(client_buf) + + # handle client hello + # send server hello, encrypted extensions, certificate, certificate verify, finished + server_buf = create_buffers() + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) + client_input = merge_buffers(server_buf) + reset_buffers(server_buf) + + # mess with finished verify data + client_input = client_input[:-4] + bytes(4) + + # handle server hello, encrypted extensions, certificate, certificate verify, finished + with self.assertRaises(tls.AlertDecryptError): + client.handle_message(client_input, client_buf) + + def test_server_unexpected_message(self): + server = self.create_server() + + server.state = State.SERVER_EXPECT_CLIENT_HELLO + with self.assertRaises(tls.AlertUnexpectedMessage): + server.handle_message(b"\x00\x00\x00\x00", create_buffers()) + + server.state = State.SERVER_EXPECT_FINISHED + with self.assertRaises(tls.AlertUnexpectedMessage): + server.handle_message(b"\x00\x00\x00\x00", create_buffers()) + + server.state = State.SERVER_POST_HANDSHAKE + with self.assertRaises(tls.AlertUnexpectedMessage): + server.handle_message(b"\x00\x00\x00\x00", create_buffers()) + + def _server_fail_hello(self, client, server): + # send client hello + client_buf = create_buffers() + client.handle_message(b"", client_buf) + 
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + reset_buffers(client_buf) + + # handle client hello + server_buf = create_buffers() + server.handle_message(server_input, server_buf) + + def test_server_unsupported_cipher_suite(self): + client = self.create_client() + client._cipher_suites = [tls.CipherSuite.AES_128_GCM_SHA256] + + server = self.create_server() + server._cipher_suites = [tls.CipherSuite.AES_256_GCM_SHA384] + + with self.assertRaises(tls.AlertHandshakeFailure) as cm: + self._server_fail_hello(client, server) + self.assertEqual(str(cm.exception), "No supported cipher suite") + + def test_server_unsupported_signature_algorithm(self): + client = self.create_client() + client._signature_algorithms = [tls.SignatureAlgorithm.ED448] + + server = self.create_server() + + with self.assertRaises(tls.AlertHandshakeFailure) as cm: + self._server_fail_hello(client, server) + self.assertEqual(str(cm.exception), "No supported signature algorithm") + + def test_server_unsupported_version(self): + client = self.create_client() + client._supported_versions = [tls.TLS_VERSION_1_2] + + server = self.create_server() + + with self.assertRaises(tls.AlertProtocolVersion) as cm: + self._server_fail_hello(client, server) + self.assertEqual(str(cm.exception), "No supported protocol version") + + def test_server_bad_finished_verify_data(self): + client = self.create_client() + server = self.create_server() + + # send client hello + client_buf = create_buffers() + client.handle_message(b"", client_buf) + self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + reset_buffers(client_buf) + + # handle client hello + # send server hello, encrypted extensions, certificate, certificate verify, finished + server_buf = create_buffers() + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) + client_input = 
merge_buffers(server_buf) + reset_buffers(server_buf) + + # handle server hello, encrypted extensions, certificate, certificate verify, finished + # send finished + client.handle_message(client_input, client_buf) + self.assertEqual(client.state, State.CLIENT_POST_HANDSHAKE) + server_input = merge_buffers(client_buf) + reset_buffers(client_buf) + + # mess with finished verify data + server_input = server_input[:-4] + bytes(4) + + # handle finished + with self.assertRaises(tls.AlertDecryptError): + server.handle_message(server_input, server_buf) + + def _handshake(self, client, server): + # send client hello + client_buf = create_buffers() + client.handle_message(b"", client_buf) + self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + self.assertGreaterEqual(len(server_input), 213) + self.assertLessEqual(len(server_input), 358) + reset_buffers(client_buf) + + # handle client hello + # send server hello, encrypted extensions, certificate, certificate verify, finished, (session ticket) + server_buf = create_buffers() + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) + client_input = merge_buffers(server_buf) + self.assertGreaterEqual(len(client_input), 600) + self.assertLessEqual(len(client_input), 2316) + + reset_buffers(server_buf) + + # handle server hello, encrypted extensions, certificate, certificate verify, finished, (session ticket) + # send finished + client.handle_message(client_input, client_buf) + self.assertEqual(client.state, State.CLIENT_POST_HANDSHAKE) + server_input = merge_buffers(client_buf) + self.assertEqual(len(server_input), 52) + reset_buffers(client_buf) + + # handle finished + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_POST_HANDSHAKE) + client_input = merge_buffers(server_buf) + self.assertEqual(len(client_input), 0) + + # check keys match + self.assertEqual(client._dec_key, 
server._enc_key) + self.assertEqual(client._enc_key, server._dec_key) + + # check cipher suite + self.assertEqual( + client.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 + ) + self.assertEqual( + server.key_schedule.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384 + ) + + def test_handshake(self): + client = self.create_client() + server = self.create_server() + + self._handshake(client, server) + + # check ALPN matches + self.assertEqual(client.alpn_negotiated, None) + self.assertEqual(server.alpn_negotiated, None) + + def test_handshake_ecdsa_secp256r1(self): + server = self.create_server() + server.certificate, server.certificate_private_key = generate_ec_certificate( + common_name="example.com", curve=ec.SECP256R1 + ) + + client = self.create_client( + cadata=server.certificate.public_bytes(serialization.Encoding.PEM), + cafile=None, + ) + + self._handshake(client, server) + + # check ALPN matches + self.assertEqual(client.alpn_negotiated, None) + self.assertEqual(server.alpn_negotiated, None) + + def test_handshake_with_alpn(self): + client = self.create_client(alpn_protocols=["hq-20"]) + server = self.create_server(alpn_protocols=["hq-20", "h3-20"]) + + self._handshake(client, server) + + # check ALPN matches + self.assertEqual(client.alpn_negotiated, "hq-20") + self.assertEqual(server.alpn_negotiated, "hq-20") + + def test_handshake_with_alpn_fail(self): + client = self.create_client(alpn_protocols=["hq-20"]) + server = self.create_server(alpn_protocols=["h3-20"]) + + with self.assertRaises(tls.AlertHandshakeFailure) as cm: + self._handshake(client, server) + self.assertEqual(str(cm.exception), "No common ALPN protocols") + + def test_handshake_with_rsa_pkcs1_sha256_signature(self): + client = self.create_client() + client._signature_algorithms = [tls.SignatureAlgorithm.RSA_PKCS1_SHA256] + server = self.create_server() + + self._handshake(client, server) + + def test_handshake_with_certificate_error(self): + client = 
self.create_client(cafile=None) + server = self.create_server() + + with self.assertRaises(tls.AlertBadCertificate) as cm: + self._handshake(client, server) + self.assertEqual(str(cm.exception), "unable to get local issuer certificate") + + def test_handshake_with_certificate_no_verify(self): + client = self.create_client(cafile=None, verify_mode=ssl.CERT_NONE) + server = self.create_server() + + self._handshake(client, server) + + def test_handshake_with_grease_group(self): + client = self.create_client() + client._supported_groups = [tls.Group.GREASE, tls.Group.SECP256R1] + server = self.create_server() + + self._handshake(client, server) + + def test_handshake_with_x25519(self): + client = self.create_client() + client._supported_groups = [tls.Group.X25519] + server = self.create_server() + + try: + self._handshake(client, server) + except UnsupportedAlgorithm as exc: + self.skipTest(str(exc)) + + def test_handshake_with_x448(self): + client = self.create_client() + client._supported_groups = [tls.Group.X448] + server = self.create_server() + + try: + self._handshake(client, server) + except UnsupportedAlgorithm as exc: + self.skipTest(str(exc)) + + def test_session_ticket(self): + client_tickets = [] + server_tickets = [] + + def client_new_ticket(ticket): + client_tickets.append(ticket) + + def server_get_ticket(label): + for t in server_tickets: + if t.ticket == label: + return t + return None + + def server_new_ticket(ticket): + server_tickets.append(ticket) + + def first_handshake(): + client = self.create_client() + client.new_session_ticket_cb = client_new_ticket + + server = self.create_server() + server.new_session_ticket_cb = server_new_ticket + + self._handshake(client, server) + + # check session resumption was not used + self.assertFalse(client.session_resumed) + self.assertFalse(server.session_resumed) + + # check tickets match + self.assertEqual(len(client_tickets), 1) + self.assertEqual(len(server_tickets), 1) + 
self.assertEqual(client_tickets[0].ticket, server_tickets[0].ticket) + self.assertEqual( + client_tickets[0].resumption_secret, server_tickets[0].resumption_secret + ) + + def second_handshake(): + client = self.create_client() + client.session_ticket = client_tickets[0] + + server = self.create_server() + server.get_session_ticket_cb = server_get_ticket + + # send client hello with pre_shared_key + client_buf = create_buffers() + client.handle_message(b"", client_buf) + self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + self.assertGreaterEqual(len(server_input), 383) + self.assertLessEqual(len(server_input), 483) + reset_buffers(client_buf) + + # handle client hello + # send server hello, encrypted extensions, finished + server_buf = create_buffers() + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) + client_input = merge_buffers(server_buf) + self.assertEqual(len(client_input), 307) + reset_buffers(server_buf) + + # handle server hello, encrypted extensions, certificate, certificate verify, finished + # send finished + client.handle_message(client_input, client_buf) + self.assertEqual(client.state, State.CLIENT_POST_HANDSHAKE) + server_input = merge_buffers(client_buf) + self.assertEqual(len(server_input), 52) + reset_buffers(client_buf) + + # handle finished + # send new_session_ticket + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_POST_HANDSHAKE) + client_input = merge_buffers(server_buf) + self.assertEqual(len(client_input), 0) + reset_buffers(server_buf) + + # check keys match + self.assertEqual(client._dec_key, server._enc_key) + self.assertEqual(client._enc_key, server._dec_key) + + # check session resumption was used + self.assertTrue(client.session_resumed) + self.assertTrue(server.session_resumed) + + def second_handshake_bad_binder(): + client = self.create_client() + 
client.session_ticket = client_tickets[0] + + server = self.create_server() + server.get_session_ticket_cb = server_get_ticket + + # send client hello with pre_shared_key + client_buf = create_buffers() + client.handle_message(b"", client_buf) + self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + self.assertGreaterEqual(len(server_input), 383) + self.assertLessEqual(len(server_input), 483) + reset_buffers(client_buf) + + # tamper with binder + server_input = server_input[:-4] + bytes(4) + + # handle client hello + # send server hello, encrypted extensions, finished + server_buf = create_buffers() + with self.assertRaises(tls.AlertHandshakeFailure) as cm: + server.handle_message(server_input, server_buf) + self.assertEqual(str(cm.exception), "PSK validation failed") + + def second_handshake_bad_pre_shared_key(): + client = self.create_client() + client.session_ticket = client_tickets[0] + + server = self.create_server() + server.get_session_ticket_cb = server_get_ticket + + # send client hello with pre_shared_key + client_buf = create_buffers() + client.handle_message(b"", client_buf) + self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + self.assertGreaterEqual(len(server_input), 383) + self.assertLessEqual(len(server_input), 483) + reset_buffers(client_buf) + + # handle client hello + # send server hello, encrypted extensions, finished + server_buf = create_buffers() + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) + + # tamper with pre_share_key index + buf = server_buf[tls.Epoch.INITIAL] + buf.seek(buf.tell() - 1) + buf.push_uint8(1) + client_input = merge_buffers(server_buf) + self.assertEqual(len(client_input), 307) + reset_buffers(server_buf) + + # handle server hello and bomb + with self.assertRaises(tls.AlertIllegalParameter): + client.handle_message(client_input, client_buf) + + 
first_handshake() + second_handshake() + second_handshake_bad_binder() + second_handshake_bad_pre_shared_key() + + +class TlsTest(TestCase): + def test_pull_client_hello(self): + buf = Buffer(data=load("tls_client_hello.bin")) + hello = pull_client_hello(buf) + self.assertTrue(buf.eof()) + + self.assertEqual( + hello.random, + binascii.unhexlify( + "18b2b23bf3e44b5d52ccfe7aecbc5ff14eadc3d349fabf804d71f165ae76e7d5" + ), + ) + self.assertEqual( + hello.session_id, + binascii.unhexlify( + "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" + ), + ) + self.assertEqual( + hello.cipher_suites, + [ + tls.CipherSuite.AES_256_GCM_SHA384, + tls.CipherSuite.AES_128_GCM_SHA256, + tls.CipherSuite.CHACHA20_POLY1305_SHA256, + ], + ) + self.assertEqual(hello.compression_methods, [tls.CompressionMethod.NULL]) + + # extensions + self.assertEqual(hello.alpn_protocols, None) + self.assertEqual( + hello.key_share, + [ + ( + tls.Group.SECP256R1, + binascii.unhexlify( + "047bfea344467535054263b75def60cffa82405a211b68d1eb8d1d944e67aef8" + "93c7665a5473d032cfaf22a73da28eb4aacae0017ed12557b5791f98a1e84f15" + "b0" + ), + ) + ], + ) + self.assertEqual( + hello.psk_key_exchange_modes, [tls.PskKeyExchangeMode.PSK_DHE_KE] + ) + self.assertEqual(hello.server_name, None) + self.assertEqual( + hello.signature_algorithms, + [ + tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, + tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA1, + ], + ) + self.assertEqual(hello.supported_groups, [tls.Group.SECP256R1]) + self.assertEqual( + hello.supported_versions, + [ + tls.TLS_VERSION_1_3, + tls.TLS_VERSION_1_3_DRAFT_28, + tls.TLS_VERSION_1_3_DRAFT_27, + tls.TLS_VERSION_1_3_DRAFT_26, + ], + ) + + self.assertEqual( + hello.other_extensions, + [ + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + CLIENT_QUIC_TRANSPORT_PARAMETERS, + ) + ], + ) + + def test_pull_client_hello_with_alpn(self): + buf = 
Buffer(data=load("tls_client_hello_with_alpn.bin")) + hello = pull_client_hello(buf) + self.assertTrue(buf.eof()) + + self.assertEqual( + hello.random, + binascii.unhexlify( + "ed575c6fbd599c4dfaabd003dca6e860ccdb0e1782c1af02e57bf27cb6479b76" + ), + ) + self.assertEqual(hello.session_id, b"") + self.assertEqual( + hello.cipher_suites, + [ + tls.CipherSuite.AES_128_GCM_SHA256, + tls.CipherSuite.AES_256_GCM_SHA384, + tls.CipherSuite.CHACHA20_POLY1305_SHA256, + tls.CipherSuite.EMPTY_RENEGOTIATION_INFO_SCSV, + ], + ) + self.assertEqual(hello.compression_methods, [tls.CompressionMethod.NULL]) + + # extensions + self.assertEqual(hello.alpn_protocols, ["h3-19"]) + self.assertEqual(hello.early_data, False) + self.assertEqual( + hello.key_share, + [ + ( + tls.Group.SECP256R1, + binascii.unhexlify( + "048842315c437bb0ce2929c816fee4e942ec5cb6db6a6b9bf622680188ebb0d4" + "b652e69033f71686aa01cbc79155866e264c9f33f45aa16b0dfa10a222e3a669" + "22" + ), + ) + ], + ) + self.assertEqual( + hello.psk_key_exchange_modes, [tls.PskKeyExchangeMode.PSK_DHE_KE] + ) + self.assertEqual(hello.server_name, "cloudflare-quic.com") + self.assertEqual( + hello.signature_algorithms, + [ + tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, + tls.SignatureAlgorithm.ECDSA_SECP384R1_SHA384, + tls.SignatureAlgorithm.ECDSA_SECP521R1_SHA512, + tls.SignatureAlgorithm.ED25519, + tls.SignatureAlgorithm.ED448, + tls.SignatureAlgorithm.RSA_PSS_PSS_SHA256, + tls.SignatureAlgorithm.RSA_PSS_PSS_SHA384, + tls.SignatureAlgorithm.RSA_PSS_PSS_SHA512, + tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, + tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA384, + tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA512, + tls.SignatureAlgorithm.RSA_PKCS1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA384, + tls.SignatureAlgorithm.RSA_PKCS1_SHA512, + ], + ) + self.assertEqual( + hello.supported_groups, + [ + tls.Group.SECP256R1, + tls.Group.X25519, + tls.Group.SECP384R1, + tls.Group.SECP521R1, + ], + ) + self.assertEqual(hello.supported_versions, 
[tls.TLS_VERSION_1_3]) + + # serialize + buf = Buffer(1000) + push_client_hello(buf, hello) + self.assertEqual(len(buf.data), len(load("tls_client_hello_with_alpn.bin"))) + + def test_pull_client_hello_with_psk(self): + buf = Buffer(data=load("tls_client_hello_with_psk.bin")) + hello = pull_client_hello(buf) + + self.assertEqual(hello.early_data, True) + self.assertEqual( + hello.pre_shared_key, + tls.OfferedPsks( + identities=[ + ( + binascii.unhexlify( + "fab3dc7d79f35ea53e9adf21150e601591a750b80cde0cd167fef6e0cdbc032a" + "c4161fc5c5b66679de49524bd5624c50d71ba3e650780a4bfe402d6a06a00525" + "0b5dc52085233b69d0dd13924cc5c713a396784ecafc59f5ea73c1585d79621b" + "8a94e4f2291b17427d5185abf4a994fca74ee7a7f993a950c71003fc7cf8" + ), + 2067156378, + ) + ], + binders=[ + binascii.unhexlify( + "1788ad43fdff37cfc628f24b6ce7c8c76180705380da17da32811b5bae4e78" + "d7aaaf65a9b713872f2bb28818ca1a6b01" + ) + ], + ), + ) + + self.assertTrue(buf.eof()) + + # serialize + buf = Buffer(1000) + push_client_hello(buf, hello) + self.assertEqual(buf.data, load("tls_client_hello_with_psk.bin")) + + def test_pull_client_hello_with_sni(self): + buf = Buffer(data=load("tls_client_hello_with_sni.bin")) + hello = pull_client_hello(buf) + self.assertTrue(buf.eof()) + + self.assertEqual( + hello.random, + binascii.unhexlify( + "987d8934140b0a42cc5545071f3f9f7f61963d7b6404eb674c8dbe513604346b" + ), + ) + self.assertEqual( + hello.session_id, + binascii.unhexlify( + "26b19bdd30dbf751015a3a16e13bd59002dfe420b799d2a5cd5e11b8fa7bcb66" + ), + ) + self.assertEqual( + hello.cipher_suites, + [ + tls.CipherSuite.AES_256_GCM_SHA384, + tls.CipherSuite.AES_128_GCM_SHA256, + tls.CipherSuite.CHACHA20_POLY1305_SHA256, + ], + ) + self.assertEqual(hello.compression_methods, [tls.CompressionMethod.NULL]) + + # extensions + self.assertEqual(hello.alpn_protocols, None) + self.assertEqual( + hello.key_share, + [ + ( + tls.Group.SECP256R1, + binascii.unhexlify( + 
"04b62d70f907c814cd65d0f73b8b991f06b70c77153f548410a191d2b19764a2" + "ecc06065a480efa9e1f10c8da6e737d5bfc04be3f773e20a0c997f51b5621280" + "40" + ), + ) + ], + ) + self.assertEqual( + hello.psk_key_exchange_modes, [tls.PskKeyExchangeMode.PSK_DHE_KE] + ) + self.assertEqual(hello.server_name, "cloudflare-quic.com") + self.assertEqual( + hello.signature_algorithms, + [ + tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, + tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA1, + ], + ) + self.assertEqual(hello.supported_groups, [tls.Group.SECP256R1]) + self.assertEqual( + hello.supported_versions, + [ + tls.TLS_VERSION_1_3, + tls.TLS_VERSION_1_3_DRAFT_28, + tls.TLS_VERSION_1_3_DRAFT_27, + tls.TLS_VERSION_1_3_DRAFT_26, + ], + ) + + self.assertEqual( + hello.other_extensions, + [ + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + CLIENT_QUIC_TRANSPORT_PARAMETERS, + ) + ], + ) + + # serialize + buf = Buffer(1000) + push_client_hello(buf, hello) + self.assertEqual(buf.data, load("tls_client_hello_with_sni.bin")) + + def test_push_client_hello(self): + hello = ClientHello( + random=binascii.unhexlify( + "18b2b23bf3e44b5d52ccfe7aecbc5ff14eadc3d349fabf804d71f165ae76e7d5" + ), + session_id=binascii.unhexlify( + "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" + ), + cipher_suites=[ + tls.CipherSuite.AES_256_GCM_SHA384, + tls.CipherSuite.AES_128_GCM_SHA256, + tls.CipherSuite.CHACHA20_POLY1305_SHA256, + ], + compression_methods=[tls.CompressionMethod.NULL], + key_share=[ + ( + tls.Group.SECP256R1, + binascii.unhexlify( + "047bfea344467535054263b75def60cffa82405a211b68d1eb8d1d944e67aef8" + "93c7665a5473d032cfaf22a73da28eb4aacae0017ed12557b5791f98a1e84f15" + "b0" + ), + ) + ], + psk_key_exchange_modes=[tls.PskKeyExchangeMode.PSK_DHE_KE], + signature_algorithms=[ + tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, + tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, + 
tls.SignatureAlgorithm.RSA_PKCS1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA1, + ], + supported_groups=[tls.Group.SECP256R1], + supported_versions=[ + tls.TLS_VERSION_1_3, + tls.TLS_VERSION_1_3_DRAFT_28, + tls.TLS_VERSION_1_3_DRAFT_27, + tls.TLS_VERSION_1_3_DRAFT_26, + ], + other_extensions=[ + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + CLIENT_QUIC_TRANSPORT_PARAMETERS, + ) + ], + ) + + buf = Buffer(1000) + push_client_hello(buf, hello) + self.assertEqual(buf.data, load("tls_client_hello.bin")) + + def test_pull_server_hello(self): + buf = Buffer(data=load("tls_server_hello.bin")) + hello = pull_server_hello(buf) + self.assertTrue(buf.eof()) + + self.assertEqual( + hello.random, + binascii.unhexlify( + "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280" + ), + ) + self.assertEqual( + hello.session_id, + binascii.unhexlify( + "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" + ), + ) + self.assertEqual(hello.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384) + self.assertEqual(hello.compression_method, tls.CompressionMethod.NULL) + self.assertEqual( + hello.key_share, + ( + tls.Group.SECP256R1, + binascii.unhexlify( + "048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189" + "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf" + "b2" + ), + ), + ) + self.assertEqual(hello.pre_shared_key, None) + self.assertEqual(hello.supported_version, tls.TLS_VERSION_1_3) + + def test_pull_server_hello_with_psk(self): + buf = Buffer(data=load("tls_server_hello_with_psk.bin")) + hello = pull_server_hello(buf) + self.assertTrue(buf.eof()) + + self.assertEqual( + hello.random, + binascii.unhexlify( + "ccbaaf04fc1bd5143b2cc6b97520cf37d91470dbfc8127131a7bf0f941e3a137" + ), + ) + self.assertEqual( + hello.session_id, + binascii.unhexlify( + "9483e7e895d0f4cec17086b0849601c0632662cd764e828f2f892f4c4b7771b0" + ), + ) + self.assertEqual(hello.cipher_suite, tls.CipherSuite.AES_256_GCM_SHA384) + 
self.assertEqual(hello.compression_method, tls.CompressionMethod.NULL) + self.assertEqual( + hello.key_share, + ( + tls.Group.SECP256R1, + binascii.unhexlify( + "0485d7cecbebfc548fc657bf51b8e8da842a4056b164a27f7702ca318c16e488" + "18b6409593b15c6649d6f459387a53128b164178adc840179aad01d36ce95d62" + "76" + ), + ), + ) + self.assertEqual(hello.pre_shared_key, 0) + self.assertEqual(hello.supported_version, tls.TLS_VERSION_1_3) + + # serialize + buf = Buffer(1000) + push_server_hello(buf, hello) + self.assertEqual(buf.data, load("tls_server_hello_with_psk.bin")) + + def test_pull_server_hello_with_unknown_extension(self): + buf = Buffer(data=load("tls_server_hello_with_unknown_extension.bin")) + hello = pull_server_hello(buf) + self.assertTrue(buf.eof()) + + self.assertEqual( + hello, + ServerHello( + random=binascii.unhexlify( + "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280" + ), + session_id=binascii.unhexlify( + "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" + ), + cipher_suite=tls.CipherSuite.AES_256_GCM_SHA384, + compression_method=tls.CompressionMethod.NULL, + key_share=( + tls.Group.SECP256R1, + binascii.unhexlify( + "048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189" + "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf" + "b2" + ), + ), + supported_version=tls.TLS_VERSION_1_3, + other_extensions=[(12345, b"foo")], + ), + ) + + # serialize + buf = Buffer(1000) + push_server_hello(buf, hello) + self.assertEqual(buf.data, load("tls_server_hello_with_unknown_extension.bin")) + + def test_push_server_hello(self): + hello = ServerHello( + random=binascii.unhexlify( + "ada85271d19680c615ea7336519e3fdf6f1e26f3b1075ee1de96ffa8884e8280" + ), + session_id=binascii.unhexlify( + "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" + ), + cipher_suite=tls.CipherSuite.AES_256_GCM_SHA384, + compression_method=tls.CompressionMethod.NULL, + key_share=( + tls.Group.SECP256R1, + binascii.unhexlify( 
+ "048b27d0282242d84b7fcc02a9c4f13eca0329e3c7029aa34a33794e6e7ba189" + "5cca1c503bf0378ac6937c354912116ff3251026bca1958d7f387316c83ae6cf" + "b2" + ), + ), + supported_version=tls.TLS_VERSION_1_3, + ) + + buf = Buffer(1000) + push_server_hello(buf, hello) + self.assertEqual(buf.data, load("tls_server_hello.bin")) + + def test_pull_new_session_ticket(self): + buf = Buffer(data=load("tls_new_session_ticket.bin")) + new_session_ticket = pull_new_session_ticket(buf) + self.assertIsNotNone(new_session_ticket) + self.assertTrue(buf.eof()) + + self.assertEqual( + new_session_ticket, + NewSessionTicket( + ticket_lifetime=86400, + ticket_age_add=3303452425, + ticket_nonce=b"", + ticket=binascii.unhexlify( + "dbe6f1a77a78c0426bfa607cd0d02b350247d90618704709596beda7e962cc81" + ), + max_early_data_size=0xFFFFFFFF, + ), + ) + + # serialize + buf = Buffer(100) + push_new_session_ticket(buf, new_session_ticket) + self.assertEqual(buf.data, load("tls_new_session_ticket.bin")) + + def test_pull_new_session_ticket_with_unknown_extension(self): + buf = Buffer(data=load("tls_new_session_ticket_with_unknown_extension.bin")) + new_session_ticket = pull_new_session_ticket(buf) + self.assertIsNotNone(new_session_ticket) + self.assertTrue(buf.eof()) + + self.assertEqual( + new_session_ticket, + NewSessionTicket( + ticket_lifetime=86400, + ticket_age_add=3303452425, + ticket_nonce=b"", + ticket=binascii.unhexlify( + "dbe6f1a77a78c0426bfa607cd0d02b350247d90618704709596beda7e962cc81" + ), + max_early_data_size=0xFFFFFFFF, + other_extensions=[(12345, b"foo")], + ), + ) + + # serialize + buf = Buffer(100) + push_new_session_ticket(buf, new_session_ticket) + self.assertEqual( + buf.data, load("tls_new_session_ticket_with_unknown_extension.bin") + ) + + def test_encrypted_extensions(self): + data = load("tls_encrypted_extensions.bin") + buf = Buffer(data=data) + extensions = pull_encrypted_extensions(buf) + self.assertIsNotNone(extensions) + self.assertTrue(buf.eof()) + + self.assertEqual( + 
extensions, + EncryptedExtensions( + other_extensions=[ + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + SERVER_QUIC_TRANSPORT_PARAMETERS, + ) + ] + ), + ) + + # serialize + buf = Buffer(capacity=100) + push_encrypted_extensions(buf, extensions) + self.assertEqual(buf.data, data) + + def test_encrypted_extensions_with_alpn(self): + data = load("tls_encrypted_extensions_with_alpn.bin") + buf = Buffer(data=data) + extensions = pull_encrypted_extensions(buf) + self.assertIsNotNone(extensions) + self.assertTrue(buf.eof()) + + self.assertEqual( + extensions, + EncryptedExtensions( + alpn_protocol="hq-20", + other_extensions=[ + (tls.ExtensionType.SERVER_NAME, b""), + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + SERVER_QUIC_TRANSPORT_PARAMETERS_2, + ), + ], + ), + ) + + # serialize + buf = Buffer(115) + push_encrypted_extensions(buf, extensions) + self.assertTrue(buf.eof()) + + def test_pull_encrypted_extensions_with_alpn_and_early_data(self): + buf = Buffer(data=load("tls_encrypted_extensions_with_alpn_and_early_data.bin")) + extensions = pull_encrypted_extensions(buf) + self.assertIsNotNone(extensions) + self.assertTrue(buf.eof()) + + self.assertEqual( + extensions, + EncryptedExtensions( + alpn_protocol="hq-20", + early_data=True, + other_extensions=[ + (tls.ExtensionType.SERVER_NAME, b""), + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS, + SERVER_QUIC_TRANSPORT_PARAMETERS_3, + ), + ], + ), + ) + + # serialize + buf = Buffer(116) + push_encrypted_extensions(buf, extensions) + self.assertTrue(buf.eof()) + + def test_pull_certificate(self): + buf = Buffer(data=load("tls_certificate.bin")) + certificate = pull_certificate(buf) + self.assertTrue(buf.eof()) + + self.assertEqual(certificate.request_context, b"") + self.assertEqual(certificate.certificates, [(CERTIFICATE_DATA, b"")]) + + def test_push_certificate(self): + certificate = Certificate( + request_context=b"", certificates=[(CERTIFICATE_DATA, b"")] + ) + + buf = Buffer(1600) + push_certificate(buf, 
certificate) + self.assertEqual(buf.data, load("tls_certificate.bin")) + + def test_pull_certificate_verify(self): + buf = Buffer(data=load("tls_certificate_verify.bin")) + verify = pull_certificate_verify(buf) + self.assertTrue(buf.eof()) + + self.assertEqual(verify.algorithm, tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256) + self.assertEqual(verify.signature, CERTIFICATE_VERIFY_SIGNATURE) + + def test_push_certificate_verify(self): + verify = CertificateVerify( + algorithm=tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, + signature=CERTIFICATE_VERIFY_SIGNATURE, + ) + + buf = Buffer(400) + push_certificate_verify(buf, verify) + self.assertEqual(buf.data, load("tls_certificate_verify.bin")) + + def test_pull_finished(self): + buf = Buffer(data=load("tls_finished.bin")) + finished = pull_finished(buf) + self.assertTrue(buf.eof()) + + self.assertEqual( + finished.verify_data, + binascii.unhexlify( + "f157923234ff9a4921aadb2e0ec7b1a30fce73fb9ec0c4276f9af268f408ec68" + ), + ) + + def test_push_finished(self): + finished = Finished( + verify_data=binascii.unhexlify( + "f157923234ff9a4921aadb2e0ec7b1a30fce73fb9ec0c4276f9af268f408ec68" + ) + ) + + buf = Buffer(128) + push_finished(buf, finished) + self.assertEqual(buf.data, load("tls_finished.bin")) + + +class VerifyCertificateTest(TestCase): + def test_verify_certificate_chain(self): + with open(SERVER_CERTFILE, "rb") as fp: + certificate = load_pem_x509_certificates(fp.read())[0] + + with patch("aioquic.tls.utcnow") as mock_utcnow: + mock_utcnow.return_value = certificate.not_valid_before + + # fail + with self.assertRaises(tls.AlertBadCertificate) as cm: + verify_certificate(certificate=certificate, server_name="localhost") + self.assertEqual( + str(cm.exception), "unable to get local issuer certificate" + ) + + # ok + verify_certificate( + cafile=SERVER_CACERTFILE, + certificate=certificate, + server_name="localhost", + ) + + def test_verify_certificate_chain_self_signed(self): + certificate, _ = generate_ec_certificate( + 
common_name="localhost", curve=ec.SECP256R1 + ) + + with patch("aioquic.tls.utcnow") as mock_utcnow: + mock_utcnow.return_value = certificate.not_valid_before + + # fail + with self.assertRaises(tls.AlertBadCertificate) as cm: + verify_certificate(certificate=certificate, server_name="localhost") + self.assertEqual(str(cm.exception), "self signed certificate") + + # ok + verify_certificate( + cadata=certificate.public_bytes(serialization.Encoding.PEM), + certificate=certificate, + server_name="localhost", + ) + + @patch("aioquic.tls.lib.X509_STORE_new") + def test_verify_certificate_chain_internal_error(self, mock_store_new): + mock_store_new.return_value = tls.ffi.NULL + + certificate, _ = generate_ec_certificate( + common_name="localhost", curve=ec.SECP256R1 + ) + + with self.assertRaises(tls.AlertInternalError) as cm: + verify_certificate( + cadata=certificate.public_bytes(serialization.Encoding.PEM), + certificate=certificate, + server_name="localhost", + ) + self.assertEqual(str(cm.exception), "OpenSSL call to X509_store_new failed") + + def test_verify_dates(self): + certificate, _ = generate_ec_certificate( + common_name="example.com", curve=ec.SECP256R1 + ) + cadata = certificate.public_bytes(serialization.Encoding.PEM) + + #  too early + with patch("aioquic.tls.utcnow") as mock_utcnow: + mock_utcnow.return_value = ( + certificate.not_valid_before - datetime.timedelta(seconds=1) + ) + with self.assertRaises(tls.AlertCertificateExpired) as cm: + verify_certificate( + cadata=cadata, certificate=certificate, server_name="example.com" + ) + self.assertEqual(str(cm.exception), "Certificate is not valid yet") + + # valid + with patch("aioquic.tls.utcnow") as mock_utcnow: + mock_utcnow.return_value = certificate.not_valid_before + verify_certificate( + cadata=cadata, certificate=certificate, server_name="example.com" + ) + + with patch("aioquic.tls.utcnow") as mock_utcnow: + mock_utcnow.return_value = certificate.not_valid_after + verify_certificate( + 
cadata=cadata, certificate=certificate, server_name="example.com" + ) + + # too late + with patch("aioquic.tls.utcnow") as mock_utcnow: + mock_utcnow.return_value = certificate.not_valid_after + datetime.timedelta( + seconds=1 + ) + with self.assertRaises(tls.AlertCertificateExpired) as cm: + verify_certificate( + cadata=cadata, certificate=certificate, server_name="example.com" + ) + self.assertEqual(str(cm.exception), "Certificate is no longer valid") + + def test_verify_subject(self): + certificate, _ = generate_ec_certificate( + common_name="example.com", curve=ec.SECP256R1 + ) + cadata = certificate.public_bytes(serialization.Encoding.PEM) + + with patch("aioquic.tls.utcnow") as mock_utcnow: + mock_utcnow.return_value = certificate.not_valid_before + + # valid + verify_certificate( + cadata=cadata, certificate=certificate, server_name="example.com" + ) + + # invalid + with self.assertRaises(tls.AlertBadCertificate) as cm: + verify_certificate( + cadata=cadata, + certificate=certificate, + server_name="test.example.com", + ) + self.assertEqual( + str(cm.exception), + "hostname 'test.example.com' doesn't match 'example.com'", + ) + + with self.assertRaises(tls.AlertBadCertificate) as cm: + verify_certificate( + cadata=cadata, certificate=certificate, server_name="acme.com" + ) + self.assertEqual( + str(cm.exception), "hostname 'acme.com' doesn't match 'example.com'" + ) + + def test_verify_subject_with_subjaltname(self): + certificate, _ = generate_ec_certificate( + alternative_names=["*.example.com", "example.com"], + common_name="example.com", + curve=ec.SECP256R1, + ) + cadata = certificate.public_bytes(serialization.Encoding.PEM) + + with patch("aioquic.tls.utcnow") as mock_utcnow: + mock_utcnow.return_value = certificate.not_valid_before + + # valid + verify_certificate( + cadata=cadata, certificate=certificate, server_name="example.com" + ) + verify_certificate( + cadata=cadata, certificate=certificate, server_name="test.example.com" + ) + + # invalid + 
with self.assertRaises(tls.AlertBadCertificate) as cm: + verify_certificate( + cadata=cadata, certificate=certificate, server_name="acme.com" + ) + self.assertEqual( + str(cm.exception), + "hostname 'acme.com' doesn't match either of '*.example.com', 'example.com'", + ) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_certificate.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_certificate.bin new file mode 100644 index 000000000000..e9bde6290548 Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_certificate.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_certificate_verify.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_certificate_verify.bin new file mode 100644 index 000000000000..ff083950b538 Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_certificate_verify.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello.bin new file mode 100644 index 000000000000..1140aecfc69f Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_alpn.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_alpn.bin new file mode 100644 index 000000000000..9113a3e75e14 Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_alpn.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_psk.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_psk.bin new file mode 100644 index 000000000000..cf695c3a7b37 Binary files /dev/null and 
b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_psk.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_sni.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_sni.bin new file mode 100644 index 000000000000..b568a86167ff Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_sni.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions.bin new file mode 100644 index 000000000000..33546c70c49c Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions_with_alpn.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions_with_alpn.bin new file mode 100644 index 000000000000..8d8436a7b338 Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions_with_alpn.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions_with_alpn_and_early_data.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions_with_alpn_and_early_data.bin new file mode 100644 index 000000000000..7e9069f6be3d Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions_with_alpn_and_early_data.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_finished.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_finished.bin new file mode 100644 index 000000000000..a64bb2f2a9d2 Binary files /dev/null and 
b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_finished.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_new_session_ticket.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_new_session_ticket.bin new file mode 100644 index 000000000000..6439b4731e04 Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_new_session_ticket.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_new_session_ticket_with_unknown_extension.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_new_session_ticket_with_unknown_extension.bin new file mode 100644 index 000000000000..83792890ace7 Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_new_session_ticket_with_unknown_extension.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_server_hello.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_server_hello.bin new file mode 100644 index 000000000000..20a6c51fe560 Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_server_hello.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_server_hello_with_psk.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_server_hello_with_psk.bin new file mode 100644 index 000000000000..deccdf59ab95 Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_server_hello_with_psk.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_server_hello_with_unknown_extension.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_server_hello_with_unknown_extension.bin new file mode 100644 index 000000000000..dbe5b7073fdb Binary files /dev/null and 
b/testing/web-platform/tests/tools/third_party/aioquic/tests/tls_server_hello_with_unknown_extension.bin differ diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/utils.py b/testing/web-platform/tests/tools/third_party/aioquic/tests/utils.py new file mode 100644 index 000000000000..60bab2b68a24 --- /dev/null +++ b/testing/web-platform/tests/tools/third_party/aioquic/tests/utils.py @@ -0,0 +1,71 @@ +import asyncio +import datetime +import logging +import os +import sys + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec + + +def generate_ec_certificate(common_name, curve=ec.SECP256R1, alternative_names=[]): + key = ec.generate_private_key(backend=default_backend(), curve=curve) + + subject = issuer = x509.Name( + [x509.NameAttribute(x509.NameOID.COMMON_NAME, common_name)] + ) + + builder = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=10)) + ) + if alternative_names: + builder = builder.add_extension( + x509.SubjectAlternativeName( + [x509.DNSName(name) for name in alternative_names] + ), + critical=False, + ) + cert = builder.sign(key, hashes.SHA256(), default_backend()) + return cert, key + + +def load(name): + path = os.path.join(os.path.dirname(__file__), name) + with open(path, "rb") as fp: + return fp.read() + + +def run(coro): + return asyncio.get_event_loop().run_until_complete(coro) + + +SERVER_CACERTFILE = os.path.join(os.path.dirname(__file__), "pycacert.pem") +SERVER_CERTFILE = os.path.join(os.path.dirname(__file__), "ssl_cert.pem") +SERVER_CERTFILE_WITH_CHAIN = os.path.join( + os.path.dirname(__file__), "ssl_cert_with_chain.pem" +) +SERVER_KEYFILE = 
os.path.join(os.path.dirname(__file__), "ssl_key.pem") +SKIP_TESTS = frozenset(os.environ.get("AIOQUIC_SKIP_TESTS", "").split(",")) + +if os.environ.get("AIOQUIC_DEBUG"): + logging.basicConfig(level=logging.DEBUG) + +if ( + sys.platform == "win32" + and sys.version_info.major == 3 + and sys.version_info.minor == 8 +): + # Python 3.8 uses ProactorEventLoop by default, + # which breaks UDP / IPv6 support, see: + # + # https://bugs.python.org/issue39148 + + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) diff --git a/testing/web-platform/tests/tools/third_party/aioquic/tests/version_negotiation.bin b/testing/web-platform/tests/tools/third_party/aioquic/tests/version_negotiation.bin new file mode 100644 index 000000000000..68509c9099ce Binary files /dev/null and b/testing/web-platform/tests/tools/third_party/aioquic/tests/version_negotiation.bin differ diff --git a/testing/web-platform/tests/tools/wpt/paths b/testing/web-platform/tests/tools/wpt/paths index 35867c4ccbfe..6b09cc8bdabe 100644 --- a/testing/web-platform/tests/tools/wpt/paths +++ b/testing/web-platform/tests/tools/wpt/paths @@ -2,5 +2,6 @@ tools/ci/ tools/docker/ tools/lint/ tools/manifest/ +tools/quic/ tools/serve/ tools/wpt/ diff --git a/testing/web-platform/tests/tools/wpt/wpt.py b/testing/web-platform/tests/tools/wpt/wpt.py index 9abf2cb19008..395db0394da6 100644 --- a/testing/web-platform/tests/tools/wpt/wpt.py +++ b/testing/web-platform/tests/tools/wpt/wpt.py @@ -31,6 +31,7 @@ def load_commands(): "script": props["script"], "parser": props.get("parser"), "parse_known": props.get("parse_known", False), + "py3only": props.get("py3only", False), "help": props.get("help"), "virtualenv": props.get("virtualenv", True), "install": props.get("install", []), @@ -40,7 +41,7 @@ def load_commands(): return rv -def parse_args(argv, commands = load_commands()): +def parse_args(argv, commands=load_commands()): parser = argparse.ArgumentParser() parser.add_argument("--venv", action="store", 
help="Path to an existing virtualenv to use") parser.add_argument("--skip-venv-setup", action="store_true", @@ -94,7 +95,8 @@ def create_complete_parser(): for command in commands: props = commands[command] - if props["virtualenv"]: + if (props["virtualenv"] and + (not props["py3only"] or sys.version_info.major == 3)): setup_virtualenv(None, False, props) subparser = import_command('wpt', command, props)[1] @@ -108,9 +110,11 @@ def create_complete_parser(): return parser + def venv_dir(): return "_venv" + str(sys.version_info[0]) + def setup_virtualenv(path, skip_venv_setup, props): if skip_venv_setup and path is None: raise ValueError("Must set --venv when --skip-venv-setup is used") diff --git a/testing/web-platform/tests/wpt b/testing/web-platform/tests/wpt index 37ab5409ec79..9930d77d0c9b 100755 --- a/testing/web-platform/tests/wpt +++ b/testing/web-platform/tests/wpt @@ -1,12 +1,21 @@ #!/usr/bin/env python if __name__ == "__main__": + import sys from tools.wpt import wpt - from sys import version_info, argv, exit - args, extra = wpt.parse_args(argv[1:]) + args, extra = wpt.parse_args(sys.argv[1:]) + commands = wpt.load_commands() + py3only = commands[args.command]["py3only"] - if args.py3 and version_info.major < 3: + if (args.py3 or py3only) and sys.version_info.major < 3: from subprocess import call - exit(call(['python3', argv[0]] + [args.command] + extra)) + try: + sys.exit(call(['python3', sys.argv[0]] + [args.command] + extra)) + except OSError as e: + if e.errno == 2: + sys.stderr.write("python3 is needed to run this command\n") + sys.exit(1) + else: + raise else: wpt.main()