Bug 1631100 [wpt PR 23075] - Set up py3-only tools/quic and its venv, a=testonly

Automatic update from web-platform-tests
Set up py3-only tools/quic and its venv (#23075)

* [quic] Set up py3-only tools/quic and its venv

This does not include the actual server, but only sets up the skeleton:
* Introduce a "py3only" to commands.json
* Create the skeleton of tools/quic and set up its venv

The plan is to have `wpt serve` call `wpt servequic` (behind a flag).

For more details, see
https://github.com/web-platform-tests/rfcs/blob/master/rfcs/quic.md

* Squashed 'tools/third_party/aioquic/' content from commit 88f258ae47

git-subtree-dir: tools/third_party/aioquic
git-subtree-split: 88f258ae47b7ec85de5ecf4104665054b969a9bb
--

wpt-commits: 4b870dfffeb7031d3657af7a950dd0ea1304bf10
wpt-pr: 23075
Robert Ma 2020-04-28 11:35:40 +00:00, committed by moz-wptsync-bot
Parent 809b86aa8e
Commit 085b58b0c1
110 changed files with 20373 additions and 6 deletions

@@ -0,0 +1,13 @@
{
"servequic": {
"path": "serve.py",
"script": "run",
"parser": "get_parser",
"py3only": true,
"help": "Start the QUIC server",
"virtualenv": true,
"requirements": [
"requirements.txt"
]
}
}
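
For illustration only, a command dispatcher honouring this entry might check the new "py3only" flag before running the command and collect the per-command requirements used to seed its virtualenv. This is a hypothetical sketch, not the actual tools/wpt implementation; the helper names are invented:

import json
import os
import sys


def load_commands(path):
    # read a commands.json file like the one above
    with open(path) as f:
        return json.load(f)


def can_run(props):
    # a command marked "py3only" must not be dispatched under Python 2
    return not (props.get("py3only") and sys.version_info.major < 3)


def venv_requirements(base_dir, props):
    # resolve the command's requirements files (assumed relative to its directory)
    if not props.get("virtualenv"):
        return []
    return [os.path.join(base_dir, name) for name in props.get("requirements", [])]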

@@ -0,0 +1 @@
aioquic==0.8.7

@@ -0,0 +1,28 @@
#!/usr/bin/env python3

import argparse
import sys


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true", default=False,
                        help="turn on verbose logging")
    return parser


def run(venv, **kwargs):
    # TODO(Hexcles): Replace this with actual implementation.
    print(sys.version)
    assert sys.version_info.major == 3
    import aioquic
    print('aioquic: ' + aioquic.__version__)


def main():
    kwargs = vars(get_parser().parse_args())
    return run(None, **kwargs)


if __name__ == '__main__':
    main()
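
The commit message notes that the eventual plan is for `wpt serve` to invoke `wpt servequic` behind a flag. That wiring is not part of this change; a purely hypothetical sketch of what it could look like (the flag, helper name, and subprocess approach are invented for illustration):

import subprocess


def maybe_start_quic_server(enable_quic, verbose=False):
    # launch the py3-only servequic command in its own process so that the
    # QUIC server always runs under Python 3, whichever interpreter runs
    # `wpt serve` itself
    if not enable_quic:
        return None
    cmd = ["python3", "./wpt", "servequic"]
    if verbose:
        cmd.append("--verbose")
    return subprocess.Popen(cmd)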

@@ -0,0 +1,12 @@
environment:
CIBW_SKIP: cp27-* cp33-* cp34-*
CIBW_TEST_COMMAND: python -m unittest discover -s {project}/tests
install:
- cmd: C:\Python36-x64\python.exe -m pip install cibuildwheel
build_script:
- cmd: C:\Python36-x64\python.exe -m cibuildwheel --output-dir wheelhouse
- ps: >-
if ($env:APPVEYOR_REPO_TAG -eq "true") {
Invoke-Expression "python -m pip install twine"
Invoke-Expression "python -m twine upload --skip-existing wheelhouse/*.whl"
}

@@ -0,0 +1 @@
*.bin binary

@@ -0,0 +1,128 @@
name: tests
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v1
with:
python-version: 3.7
- name: Install packages
run: pip install black flake8 isort mypy
- name: Run linters
run: |
flake8 examples src tests
isort -c -df -rc examples src tests
black --check --diff examples src tests
mypy examples src
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python: [3.8, 3.7, 3.6]
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python }}
- name: Disable firewall and configure compiler
if: matrix.os == 'macos-latest'
run: |
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --setglobalstate off
echo "::set-env name=AIOQUIC_SKIP_TESTS::chacha20"
echo "::set-env name=CFLAGS::-I/usr/local/opt/openssl/include"
echo "::set-env name=LDFLAGS::-L/usr/local/opt/openssl/lib"
- name: Install OpenSSL
if: matrix.os == 'windows-latest'
run: |
choco install openssl --no-progress
echo "::set-env name=CL::/IC:\Progra~1\OpenSSL-Win64\include"
echo "::set-env name=LINK::/LIBPATH:C:\Progra~1\OpenSSL-Win64\lib"
- name: Run tests
run: |
pip install -U pip setuptools wheel
pip install coverage
pip install .
coverage run -m unittest discover -v
coverage xml
shell: bash
- name: Upload coverage report
uses: codecov/codecov-action@v1
if: matrix.python != 'pypy3'
with:
token: ${{ secrets.CODECOV_TOKEN }}
package-source:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v1
with:
python-version: 3.7
- name: Build source package
run: python setup.py sdist
- name: Upload source package
uses: actions/upload-artifact@v1
with:
name: dist
path: dist/
package-wheel:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v1
with:
python-version: 3.7
- name: Install nasm
if: matrix.os == 'windows-latest'
run: choco install -y nasm
- name: Install nmake
if: matrix.os == 'windows-latest'
run: |
& "C:\Program Files (x86)\Microsoft Visual Studio\Installer\vs_installer.exe" modify `
--installPath "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise" `
--add Microsoft.VisualStudio.Component.VC.Tools.x86.x64 --passive --norestart
shell: powershell
- name: Build wheels
env:
CIBW_BEFORE_BUILD: scripts/build-openssl /tmp/vendor
CIBW_BEFORE_BUILD_WINDOWS: scripts\build-openssl.bat C:\cibw\vendor
CIBW_ENVIRONMENT: AIOQUIC_SKIP_TESTS=ipv6,loss CFLAGS=-I/tmp/vendor/include LDFLAGS=-L/tmp/vendor/lib
CIBW_ENVIRONMENT_WINDOWS: AIOQUIC_SKIP_TESTS=ipv6,loss CL="/IC:\cibw\vendor\include" LINK="/LIBPATH:C:\cibw\vendor\lib"
CIBW_SKIP: cp27-* cp33-* cp34-* cp35-* pp27-*
CIBW_TEST_COMMAND: python -m unittest discover -t {project} -s {project}/tests
run: |
pip install cibuildwheel
cibuildwheel --output-dir dist
- name: Upload wheels
uses: actions/upload-artifact@v1
with:
name: dist
path: dist/
publish:
runs-on: ubuntu-latest
needs: [lint, test, package-source, package-wheel]
steps:
- uses: actions/checkout@v2
- uses: actions/download-artifact@v1
with:
name: dist
path: dist/
- name: Publish to PyPI
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/')
uses: pypa/gh-action-pypi-publish@master
with:
user: __token__
password: ${{ secrets.PYPI_TOKEN }}

testing/web-platform/tests/tools/third_party/aioquic/.gitignore (vendored new file, 10 lines)
@@ -0,0 +1,10 @@
*.egg-info
*.pyc
*.so
.coverage
.eggs
.mypy_cache
.vscode
/build
/dist
/docs/_build

testing/web-platform/tests/tools/third_party/aioquic/LICENSE (vendored new file, 25 lines)
@@ -0,0 +1,25 @@
Copyright (c) 2019 Jeremy Lainé.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of aioquic nor the names of its contributors may
be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@@ -0,0 +1,4 @@
include LICENSE
recursive-include docs *.py *.rst Makefile
recursive-include examples *.html *.py
recursive-include tests *.bin *.pem *.py

testing/web-platform/tests/tools/third_party/aioquic/README.rst (vendored new file, 164 lines)
@@ -0,0 +1,164 @@
aioquic
=======
|rtd| |pypi-v| |pypi-pyversions| |pypi-l| |tests| |codecov| |black|
.. |rtd| image:: https://readthedocs.org/projects/aioquic/badge/?version=latest
:target: https://aioquic.readthedocs.io/
.. |pypi-v| image:: https://img.shields.io/pypi/v/aioquic.svg
:target: https://pypi.python.org/pypi/aioquic
.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/aioquic.svg
:target: https://pypi.python.org/pypi/aioquic
.. |pypi-l| image:: https://img.shields.io/pypi/l/aioquic.svg
:target: https://pypi.python.org/pypi/aioquic
.. |tests| image:: https://github.com/aiortc/aioquic/workflows/tests/badge.svg
:target: https://github.com/aiortc/aioquic/actions
.. |codecov| image:: https://img.shields.io/codecov/c/github/aiortc/aioquic.svg
:target: https://codecov.io/gh/aiortc/aioquic
.. |black| image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/python/black
What is ``aioquic``?
--------------------
``aioquic`` is a library for the QUIC network protocol in Python. It features
a minimal TLS 1.3 implementation, a QUIC stack and an HTTP/3 stack.
QUIC standardisation is not finalised yet, but ``aioquic`` closely tracks the
specification drafts and is regularly tested for interoperability against other
`QUIC implementations`_.
To learn more about ``aioquic`` please `read the documentation`_.
Why should I use ``aioquic``?
-----------------------------
``aioquic`` has been designed to be embedded into Python client and server
libraries wishing to support QUIC and / or HTTP/3. The goal is to provide a
common codebase for Python libraries in the hope of avoiding duplicated effort.
Both the QUIC and the HTTP/3 APIs follow the "bring your own I/O" pattern,
leaving actual I/O operations to the API user. This approach has a number of
advantages including making the code testable and allowing integration with
different concurrency models.
Features
--------
- QUIC stack conforming with draft-27
- HTTP/3 stack conforming with draft-27
- minimal TLS 1.3 implementation
- IPv4 and IPv6 support
- connection migration and NAT rebinding
- logging TLS traffic secrets
- logging QUIC events in QLOG format
- HTTP/3 server push support
Requirements
------------
``aioquic`` requires Python 3.6 or better, and the OpenSSL development headers.
Linux
.....
On Debian/Ubuntu run:
.. code-block:: console
$ sudo apt install libssl-dev python3-dev
On Alpine Linux you will also need the following:
.. code-block:: console
$ sudo apk add bsd-compat-headers libffi-dev
OS X
....
On OS X run:
.. code-block:: console
$ brew install openssl
You will need to set some environment variables to link against OpenSSL:
.. code-block:: console
$ export CFLAGS=-I/usr/local/opt/openssl/include
$ export LDFLAGS=-L/usr/local/opt/openssl/lib
Windows
.......
On Windows the easiest way to install OpenSSL is to use `Chocolatey`_.
.. code-block:: console
> choco install openssl
You will need to set some environment variables to link against OpenSSL:
.. code-block:: console
> $Env:CL = "/IC:\Progra~1\OpenSSL-Win64\include"
> $Env:LINK = "/LIBPATH:C:\Progra~1\OpenSSL-Win64\lib"
Running the examples
--------------------
After checking out the code using git you can run:
.. code-block:: console
$ pip install -e .
$ pip install aiofiles asgiref httpbin starlette wsproto
HTTP/3 server
.............
You can run the example server, which handles both HTTP/0.9 and HTTP/3:
.. code-block:: console
$ python examples/http3_server.py --certificate tests/ssl_cert.pem --private-key tests/ssl_key.pem
HTTP/3 client
.............
You can run the example client to perform an HTTP/3 request:
.. code-block:: console
$ python examples/http3_client.py --ca-certs tests/pycacert.pem https://localhost:4433/
Alternatively you can perform an HTTP/0.9 request:
.. code-block:: console
$ python examples/http3_client.py --ca-certs tests/pycacert.pem --legacy-http https://localhost:4433/
You can also open a WebSocket over HTTP/3:
.. code-block:: console
$ python examples/http3_client.py --ca-certs tests/pycacert.pem wss://localhost:4433/ws
License
-------
``aioquic`` is released under the `BSD license`_.
.. _read the documentation: https://aioquic.readthedocs.io/en/latest/
.. _QUIC implementations: https://github.com/quicwg/base-drafts/wiki/Implementations
.. _cryptography: https://cryptography.io/
.. _Chocolatey: https://chocolatey.org/
.. _BSD license: https://aioquic.readthedocs.io/en/latest/license.html

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SPHINXPROJ = aioquic
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

@@ -0,0 +1,32 @@
asyncio API
===========
The asyncio API provides a high-level QUIC API built on top of :mod:`asyncio`,
Python's standard asynchronous I/O framework.
``aioquic`` comes with a selection of examples, including:
- an HTTP/3 client
- an HTTP/3 server
The examples can be browsed on GitHub:
https://github.com/aiortc/aioquic/tree/master/examples
.. automodule:: aioquic.asyncio
Client
------
.. autofunction:: connect
Server
------
.. autofunction:: serve
Common
------
.. autoclass:: QuicConnectionProtocol
:members:
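
As a minimal sketch of the calling pattern (assuming a QUIC server is listening on ``localhost:4433`` and offers an HTTP/3 ALPN), a client connection can be opened with :func:`connect` and exercised with :meth:`QuicConnectionProtocol.ping`:

.. code-block:: python

    import asyncio

    from aioquic.asyncio import connect
    from aioquic.h3.connection import H3_ALPN
    from aioquic.quic.configuration import QuicConfiguration


    async def ping_once() -> None:
        # client-side QUIC configuration; the ALPN list must match the server's
        configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN)
        async with connect("localhost", 4433, configuration=configuration) as protocol:
            # send a PING frame and wait for its acknowledgement
            await protocol.ping()


    asyncio.run(ping_once())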

@@ -0,0 +1,205 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# aioquic documentation build configuration file, created by
# sphinx-quickstart on Thu Feb 8 17:22:14 2018.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
import sys, os
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('..'))
class MockBuffer:
Buffer = None
BufferReadError = None
BufferWriteError = None
class MockCrypto:
AEAD = None
CryptoError = ValueError
HeaderProtection = None
class MockPylsqpack:
Decoder = None
Encoder = None
StreamBlocked = None
sys.modules.update({
"aioquic._buffer": MockBuffer(),
"aioquic._crypto": MockCrypto(),
"pylsqpack": MockPylsqpack(),
})
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.intersphinx',
'sphinx_autodoc_typehints',
'sphinxcontrib.asyncio',
]
intersphinx_mapping = {
'cryptography': ('https://cryptography.io/en/latest', None),
'python': ('https://docs.python.org/3', None),
}
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = 'aioquic'
copyright = u'2019, Jeremy Lainé'
author = u'Jeremy Lainé'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = ''
# The full version, including alpha/beta/rc tags.
release = ''
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
'description': 'A library for QUIC in Python.',
'github_button': True,
'github_user': 'aiortc',
'github_repo': 'aioquic',
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
html_sidebars = {
'**': [
'about.html',
'navigation.html',
'relations.html',
'searchbox.html',
]
}
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'aioquicdoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'aioquic.tex', 'aioquic Documentation',
author, 'manual'),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'aioquic', 'aioquic Documentation',
[author], 1)
]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'aioquic', 'aioquic Documentation',
author, 'aioquic', 'One line description of project.',
'Miscellaneous'),
]

@@ -0,0 +1,37 @@
Design
======
Sans-IO APIs
............
Both the QUIC and the HTTP/3 APIs follow the `sans I/O`_ pattern, leaving
actual I/O operations to the API user. This approach has a number of
advantages including making the code testable and allowing integration with
different concurrency models.
TLS and encryption
..................
TLS 1.3
+++++++
``aioquic`` features a minimal TLS 1.3 implementation built upon the
`cryptography`_ library. This is because QUIC requires some APIs which are
currently unavailable in mainstream TLS implementations such as OpenSSL:
- the ability to extract traffic secrets
- the ability to operate directly on TLS messages, without using the TLS
record layer
Header protection and payload encryption
++++++++++++++++++++++++++++++++++++++++
QUIC makes extensive use of cryptographic operations to protect QUIC packet
headers and encrypt packet payloads. These operations occur for every single
packet and are a determining factor for performance. For this reason, they
are implemented as a C extension linked to `OpenSSL`_.
.. _sans I/O: https://sans-io.readthedocs.io/
.. _cryptography: https://cryptography.io/
.. _OpenSSL: https://www.openssl.org/
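
As a concrete example of the traffic-secret extraction mentioned above, a ``QuicConfiguration`` can append secrets to a key log file in the ``SSLKEYLOGFILE`` format, which tools such as Wireshark use to decrypt captured QUIC traffic (a minimal sketch; the path is illustrative):

.. code-block:: python

    from aioquic.quic.configuration import QuicConfiguration

    configuration = QuicConfiguration(is_client=True)
    # secrets are appended to this file as the TLS handshake progresses
    configuration.secrets_log_file = open("/tmp/keylog.txt", "a")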

@@ -0,0 +1,44 @@
HTTP/3 API
==========
The HTTP/3 API performs no I/O on its own, leaving this to the API user.
This allows you to integrate HTTP/3 in any Python application, regardless of
the concurrency model you are using.
Connection
----------
.. automodule:: aioquic.h3.connection
.. autoclass:: H3Connection
:members:
Events
------
.. automodule:: aioquic.h3.events
.. autoclass:: H3Event
:members:
.. autoclass:: DataReceived
:members:
.. autoclass:: HeadersReceived
:members:
.. autoclass:: PushPromiseReceived
:members:
Exceptions
----------
.. automodule:: aioquic.h3.exceptions
.. autoclass:: H3Error
:members:
.. autoclass:: NoAvailablePushIDError
:members:
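
As a condensed, illustrative sketch of driving :class:`H3Connection` by hand (loosely based on the bundled ``http3_client.py`` example; the transport plumbing and event loop are omitted):

.. code-block:: python

    from aioquic.h3.connection import H3Connection


    def send_get(quic_connection, authority: bytes, path: bytes) -> int:
        # wrap an existing sans-I/O QuicConnection with an HTTP/3 layer
        http = H3Connection(quic_connection)
        stream_id = quic_connection.get_next_available_stream_id()
        http.send_headers(
            stream_id=stream_id,
            headers=[
                (b":method", b"GET"),
                (b":scheme", b"https"),
                (b":authority", authority),
                (b":path", path),
            ],
        )
        # an empty DATA frame with end_stream=True completes the request
        http.send_data(stream_id=stream_id, data=b"", end_stream=True)
        return stream_id


    def on_quic_event(http: H3Connection, quic_event) -> None:
        # translate transport events into HTTP/3 events such as
        # HeadersReceived and DataReceived
        for h3_event in http.handle_event(quic_event):
            print(h3_event)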

@@ -0,0 +1,39 @@
aioquic
=======
|pypi-v| |pypi-pyversions| |pypi-l| |tests| |codecov|
.. |pypi-v| image:: https://img.shields.io/pypi/v/aioquic.svg
:target: https://pypi.python.org/pypi/aioquic
.. |pypi-pyversions| image:: https://img.shields.io/pypi/pyversions/aioquic.svg
:target: https://pypi.python.org/pypi/aioquic
.. |pypi-l| image:: https://img.shields.io/pypi/l/aioquic.svg
:target: https://pypi.python.org/pypi/aioquic
.. |tests| image:: https://github.com/aiortc/aioquic/workflows/tests/badge.svg
:target: https://github.com/aiortc/aioquic/actions
.. |codecov| image:: https://img.shields.io/codecov/c/github/aiortc/aioquic.svg
:target: https://codecov.io/gh/aiortc/aioquic
``aioquic`` is a library for the QUIC network protocol in Python. It features several
APIs:
- a QUIC API following the "bring your own I/O" pattern, suitable for
embedding in any framework,
- an HTTP/3 API which also follows the "bring your own I/O" pattern,
- a QUIC convenience API built on top of :mod:`asyncio`, Python's standard asynchronous
I/O framework.
.. toctree::
:maxdepth: 2
design
quic
h3
asyncio
license

@@ -0,0 +1,4 @@
License
-------
.. literalinclude:: ../LICENSE

@@ -0,0 +1,51 @@
QUIC API
========
The QUIC API performs no I/O on its own, leaving this to the API user.
This allows you to integrate QUIC in any Python application, regardless of
the concurrency model you are using.
Connection
----------
.. automodule:: aioquic.quic.connection
.. autoclass:: QuicConnection
:members:
Configuration
-------------
.. automodule:: aioquic.quic.configuration
.. autoclass:: QuicConfiguration
:members:
.. automodule:: aioquic.quic.logger
.. autoclass:: QuicLogger
:members:
Events
------
.. automodule:: aioquic.quic.events
.. autoclass:: QuicEvent
:members:
.. autoclass:: ConnectionTerminated
:members:
.. autoclass:: HandshakeCompleted
:members:
.. autoclass:: PingAcknowledged
:members:
.. autoclass:: StreamDataReceived
:members:
.. autoclass:: StreamReset
:members:
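
Because the API performs no I/O, the caller moves datagrams between :class:`QuicConnection` and a socket and polls for events. A minimal, illustrative client loop (blocking UDP, no timer handling, shown only to convey the calling pattern):

.. code-block:: python

    import socket
    import time

    from aioquic.quic.configuration import QuicConfiguration
    from aioquic.quic.connection import QuicConnection

    SERVER_ADDR = ("127.0.0.1", 4433)

    configuration = QuicConfiguration(is_client=True)
    connection = QuicConnection(configuration=configuration)
    connection.connect(SERVER_ADDR, now=time.time())

    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.settimeout(1.0)
    for _ in range(10):
        # flush any datagrams the connection wants to send
        for data, addr in connection.datagrams_to_send(now=time.time()):
            sock.sendto(data, addr)
        # feed received datagrams back into the connection ...
        try:
            data, addr = sock.recvfrom(65536)
        except socket.timeout:
            continue
        connection.receive_datagram(data, addr, now=time.time())
        # ... and drain the resulting events (HandshakeCompleted, StreamDataReceived, ...)
        event = connection.next_event()
        while event is not None:
            print(event)
            event = connection.next_event()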

@@ -0,0 +1,109 @@
#
# demo application for http3_server.py
#
import datetime
import os
from urllib.parse import urlencode
import httpbin
from asgiref.wsgi import WsgiToAsgi
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse, Response
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from starlette.websockets import WebSocketDisconnect
ROOT = os.path.dirname(__file__)
STATIC_ROOT = os.environ.get("STATIC_ROOT", os.path.join(ROOT, "htdocs"))
STATIC_URL = "/"
LOGS_PATH = os.path.join(STATIC_ROOT, "logs")
QVIS_URL = "https://qvis.edm.uhasselt.be/"
templates = Jinja2Templates(directory=os.path.join(ROOT, "templates"))
app = Starlette()
@app.route("/")
async def homepage(request):
"""
Simple homepage.
"""
await request.send_push_promise("/style.css")
return templates.TemplateResponse("index.html", {"request": request})
@app.route("/echo", methods=["POST"])
async def echo(request):
"""
HTTP echo endpoint.
"""
content = await request.body()
media_type = request.headers.get("content-type")
return Response(content, media_type=media_type)
@app.route("/logs/?")
async def logs(request):
"""
Browsable list of QLOG files.
"""
logs = []
for name in os.listdir(LOGS_PATH):
if name.endswith(".qlog"):
s = os.stat(os.path.join(LOGS_PATH, name))
file_url = "https://" + request.headers["host"] + "/logs/" + name
logs.append(
{
"date": datetime.datetime.utcfromtimestamp(s.st_mtime).strftime(
"%Y-%m-%d %H:%M:%S"
),
"file_url": file_url,
"name": name[:-5],
"qvis_url": QVIS_URL
+ "?"
+ urlencode({"file": file_url})
+ "#/sequence",
"size": s.st_size,
}
)
return templates.TemplateResponse(
"logs.html",
{
"logs": sorted(logs, key=lambda x: x["date"], reverse=True),
"request": request,
},
)
@app.route("/{size:int}")
def padding(request):
"""
Dynamically generated data, maximum 50MB.
"""
size = min(50000000, request.path_params["size"])
return PlainTextResponse("Z" * size)
@app.websocket_route("/ws")
async def ws(websocket):
"""
WebSocket echo endpoint.
"""
if "chat" in websocket.scope["subprotocols"]:
subprotocol = "chat"
else:
subprotocol = None
await websocket.accept(subprotocol=subprotocol)
try:
while True:
message = await websocket.receive_text()
await websocket.send_text(message)
except WebSocketDisconnect:
pass
app.mount("/httpbin", WsgiToAsgi(httpbin.app))
app.mount(STATIC_URL, StaticFiles(directory=STATIC_ROOT, html=True))

@@ -0,0 +1,2 @@
User-agent: *
Disallow: /logs

@@ -0,0 +1,10 @@
body {
font-family: Arial, sans-serif;
font-size: 16px;
margin: 0 auto;
width: 40em;
}
table.logs {
width: 100%;
}

@@ -0,0 +1,455 @@
import argparse
import asyncio
import json
import logging
import os
import pickle
import ssl
import time
from collections import deque
from typing import Callable, Deque, Dict, List, Optional, Union, cast
from urllib.parse import urlparse
import wsproto
import wsproto.events
import aioquic
from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.h0.connection import H0_ALPN, H0Connection
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.h3.events import (
DataReceived,
H3Event,
HeadersReceived,
PushPromiseReceived,
)
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent
from aioquic.quic.logger import QuicLogger
from aioquic.tls import SessionTicket
try:
import uvloop
except ImportError:
uvloop = None
logger = logging.getLogger("client")
HttpConnection = Union[H0Connection, H3Connection]
USER_AGENT = "aioquic/" + aioquic.__version__
class URL:
def __init__(self, url: str) -> None:
parsed = urlparse(url)
self.authority = parsed.netloc
self.full_path = parsed.path
if parsed.query:
self.full_path += "?" + parsed.query
self.scheme = parsed.scheme
class HttpRequest:
def __init__(
self, method: str, url: URL, content: bytes = b"", headers: Dict = {}
) -> None:
self.content = content
self.headers = headers
self.method = method
self.url = url
class WebSocket:
def __init__(
self, http: HttpConnection, stream_id: int, transmit: Callable[[], None]
) -> None:
self.http = http
self.queue: asyncio.Queue[str] = asyncio.Queue()
self.stream_id = stream_id
self.subprotocol: Optional[str] = None
self.transmit = transmit
self.websocket = wsproto.Connection(wsproto.ConnectionType.CLIENT)
async def close(self, code=1000, reason="") -> None:
"""
Perform the closing handshake.
"""
data = self.websocket.send(
wsproto.events.CloseConnection(code=code, reason=reason)
)
self.http.send_data(stream_id=self.stream_id, data=data, end_stream=True)
self.transmit()
async def recv(self) -> str:
"""
Receive the next message.
"""
return await self.queue.get()
async def send(self, message: str) -> None:
"""
Send a message.
"""
assert isinstance(message, str)
data = self.websocket.send(wsproto.events.TextMessage(data=message))
self.http.send_data(stream_id=self.stream_id, data=data, end_stream=False)
self.transmit()
def http_event_received(self, event: H3Event) -> None:
if isinstance(event, HeadersReceived):
for header, value in event.headers:
if header == b"sec-websocket-protocol":
self.subprotocol = value.decode()
elif isinstance(event, DataReceived):
self.websocket.receive_data(event.data)
for ws_event in self.websocket.events():
self.websocket_event_received(ws_event)
def websocket_event_received(self, event: wsproto.events.Event) -> None:
if isinstance(event, wsproto.events.TextMessage):
self.queue.put_nowait(event.data)
class HttpClient(QuicConnectionProtocol):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.pushes: Dict[int, Deque[H3Event]] = {}
self._http: Optional[HttpConnection] = None
self._request_events: Dict[int, Deque[H3Event]] = {}
self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {}
self._websockets: Dict[int, WebSocket] = {}
if self._quic.configuration.alpn_protocols[0].startswith("hq-"):
self._http = H0Connection(self._quic)
else:
self._http = H3Connection(self._quic)
async def get(self, url: str, headers: Dict = {}) -> Deque[H3Event]:
"""
Perform a GET request.
"""
return await self._request(
HttpRequest(method="GET", url=URL(url), headers=headers)
)
async def post(self, url: str, data: bytes, headers: Dict = {}) -> Deque[H3Event]:
"""
Perform a POST request.
"""
return await self._request(
HttpRequest(method="POST", url=URL(url), content=data, headers=headers)
)
async def websocket(self, url: str, subprotocols: List[str] = []) -> WebSocket:
"""
Open a WebSocket.
"""
request = HttpRequest(method="CONNECT", url=URL(url))
stream_id = self._quic.get_next_available_stream_id()
websocket = WebSocket(
http=self._http, stream_id=stream_id, transmit=self.transmit
)
self._websockets[stream_id] = websocket
headers = [
(b":method", b"CONNECT"),
(b":scheme", b"https"),
(b":authority", request.url.authority.encode()),
(b":path", request.url.full_path.encode()),
(b":protocol", b"websocket"),
(b"user-agent", USER_AGENT.encode()),
(b"sec-websocket-version", b"13"),
]
if subprotocols:
headers.append(
(b"sec-websocket-protocol", ", ".join(subprotocols).encode())
)
self._http.send_headers(stream_id=stream_id, headers=headers)
self.transmit()
return websocket
def http_event_received(self, event: H3Event) -> None:
if isinstance(event, (HeadersReceived, DataReceived)):
stream_id = event.stream_id
if stream_id in self._request_events:
# http
self._request_events[event.stream_id].append(event)
if event.stream_ended:
request_waiter = self._request_waiter.pop(stream_id)
request_waiter.set_result(self._request_events.pop(stream_id))
elif stream_id in self._websockets:
# websocket
websocket = self._websockets[stream_id]
websocket.http_event_received(event)
elif event.push_id in self.pushes:
# push
self.pushes[event.push_id].append(event)
elif isinstance(event, PushPromiseReceived):
self.pushes[event.push_id] = deque()
self.pushes[event.push_id].append(event)
def quic_event_received(self, event: QuicEvent) -> None:
#  pass event to the HTTP layer
if self._http is not None:
for http_event in self._http.handle_event(event):
self.http_event_received(http_event)
async def _request(self, request: HttpRequest):
stream_id = self._quic.get_next_available_stream_id()
self._http.send_headers(
stream_id=stream_id,
headers=[
(b":method", request.method.encode()),
(b":scheme", request.url.scheme.encode()),
(b":authority", request.url.authority.encode()),
(b":path", request.url.full_path.encode()),
(b"user-agent", USER_AGENT.encode()),
]
+ [(k.encode(), v.encode()) for (k, v) in request.headers.items()],
)
self._http.send_data(stream_id=stream_id, data=request.content, end_stream=True)
waiter = self._loop.create_future()
self._request_events[stream_id] = deque()
self._request_waiter[stream_id] = waiter
self.transmit()
return await asyncio.shield(waiter)
async def perform_http_request(
client: HttpClient, url: str, data: str, include: bool, output_dir: Optional[str],
) -> None:
# perform request
start = time.time()
if data is not None:
http_events = await client.post(
url,
data=data.encode(),
headers={"content-type": "application/x-www-form-urlencoded"},
)
else:
http_events = await client.get(url)
elapsed = time.time() - start
# print speed
octets = 0
for http_event in http_events:
if isinstance(http_event, DataReceived):
octets += len(http_event.data)
logger.info(
"Received %d bytes in %.1f s (%.3f Mbps)"
% (octets, elapsed, octets * 8 / elapsed / 1000000)
)
# output response
if output_dir is not None:
output_path = os.path.join(
output_dir, os.path.basename(urlparse(url).path) or "index.html"
)
with open(output_path, "wb") as output_file:
for http_event in http_events:
if isinstance(http_event, HeadersReceived) and include:
headers = b""
for k, v in http_event.headers:
headers += k + b": " + v + b"\r\n"
if headers:
output_file.write(headers + b"\r\n")
elif isinstance(http_event, DataReceived):
output_file.write(http_event.data)
def save_session_ticket(ticket: SessionTicket) -> None:
"""
Callback which is invoked by the TLS engine when a new session ticket
is received.
"""
logger.info("New session ticket received")
if args.session_ticket:
with open(args.session_ticket, "wb") as fp:
pickle.dump(ticket, fp)
async def run(
configuration: QuicConfiguration,
urls: List[str],
data: str,
include: bool,
output_dir: Optional[str],
local_port: int,
) -> None:
# parse URL
parsed = urlparse(urls[0])
assert parsed.scheme in (
"https",
"wss",
), "Only https:// or wss:// URLs are supported."
if ":" in parsed.netloc:
host, port_str = parsed.netloc.split(":")
port = int(port_str)
else:
host = parsed.netloc
port = 443
async with connect(
host,
port,
configuration=configuration,
create_protocol=HttpClient,
session_ticket_handler=save_session_ticket,
local_port=local_port,
) as client:
client = cast(HttpClient, client)
if parsed.scheme == "wss":
ws = await client.websocket(urls[0], subprotocols=["chat", "superchat"])
# send some messages and receive reply
for i in range(2):
message = "Hello {}, WebSocket!".format(i)
print("> " + message)
await ws.send(message)
message = await ws.recv()
print("< " + message)
await ws.close()
else:
# perform request
coros = [
perform_http_request(
client=client,
url=url,
data=data,
include=include,
output_dir=output_dir,
)
for url in urls
]
await asyncio.gather(*coros)
if __name__ == "__main__":
defaults = QuicConfiguration(is_client=True)
parser = argparse.ArgumentParser(description="HTTP/3 client")
parser.add_argument(
"url", type=str, nargs="+", help="the URL to query (must be HTTPS)"
)
parser.add_argument(
"--ca-certs", type=str, help="load CA certificates from the specified file"
)
parser.add_argument(
"-d", "--data", type=str, help="send the specified data in a POST request"
)
parser.add_argument(
"-i",
"--include",
action="store_true",
help="include the HTTP response headers in the output",
)
parser.add_argument(
"--max-data",
type=int,
help="connection-wide flow control limit (default: %d)" % defaults.max_data,
)
parser.add_argument(
"--max-stream-data",
type=int,
help="per-stream flow control limit (default: %d)" % defaults.max_stream_data,
)
parser.add_argument(
"-k",
"--insecure",
action="store_true",
help="do not validate server certificate",
)
parser.add_argument("--legacy-http", action="store_true", help="use HTTP/0.9")
parser.add_argument(
"--output-dir", type=str, help="write downloaded files to this directory",
)
parser.add_argument(
"-q", "--quic-log", type=str, help="log QUIC events to a file in QLOG format"
)
parser.add_argument(
"-l",
"--secrets-log",
type=str,
help="log secrets to a file, for use with Wireshark",
)
parser.add_argument(
"-s",
"--session-ticket",
type=str,
help="read and write session ticket from the specified file",
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="increase logging verbosity"
)
parser.add_argument(
"--local-port", type=int, default=0, help="local port to bind for connections",
)
args = parser.parse_args()
logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s %(message)s",
level=logging.DEBUG if args.verbose else logging.INFO,
)
if args.output_dir is not None and not os.path.isdir(args.output_dir):
raise Exception("%s is not a directory" % args.output_dir)
# prepare configuration
configuration = QuicConfiguration(
is_client=True, alpn_protocols=H0_ALPN if args.legacy_http else H3_ALPN
)
if args.ca_certs:
configuration.load_verify_locations(args.ca_certs)
if args.insecure:
configuration.verify_mode = ssl.CERT_NONE
if args.max_data:
configuration.max_data = args.max_data
if args.max_stream_data:
configuration.max_stream_data = args.max_stream_data
if args.quic_log:
configuration.quic_logger = QuicLogger()
if args.secrets_log:
configuration.secrets_log_file = open(args.secrets_log, "a")
if args.session_ticket:
try:
with open(args.session_ticket, "rb") as fp:
configuration.session_ticket = pickle.load(fp)
except FileNotFoundError:
pass
if uvloop is not None:
uvloop.install()
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(
run(
configuration=configuration,
urls=args.url,
data=args.data,
include=args.include,
output_dir=args.output_dir,
local_port=args.local_port,
)
)
finally:
if configuration.quic_logger is not None:
with open(args.quic_log, "w") as logger_fp:
json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4)

@@ -0,0 +1,484 @@
import argparse
import asyncio
import importlib
import json
import logging
import os
import time
from collections import deque
from email.utils import formatdate
from typing import Callable, Deque, Dict, List, Optional, Union, cast
import wsproto
import wsproto.events
import aioquic
from aioquic.asyncio import QuicConnectionProtocol, serve
from aioquic.h0.connection import H0_ALPN, H0Connection
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.h3.events import DataReceived, H3Event, HeadersReceived
from aioquic.h3.exceptions import NoAvailablePushIDError
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import DatagramFrameReceived, ProtocolNegotiated, QuicEvent
from aioquic.quic.logger import QuicLogger, QuicLoggerTrace
from aioquic.tls import SessionTicket
try:
import uvloop
except ImportError:
uvloop = None
AsgiApplication = Callable
HttpConnection = Union[H0Connection, H3Connection]
SERVER_NAME = "aioquic/" + aioquic.__version__
class HttpRequestHandler:
def __init__(
self,
*,
authority: bytes,
connection: HttpConnection,
protocol: QuicConnectionProtocol,
scope: Dict,
stream_ended: bool,
stream_id: int,
transmit: Callable[[], None],
) -> None:
self.authority = authority
self.connection = connection
self.protocol = protocol
self.queue: asyncio.Queue[Dict] = asyncio.Queue()
self.scope = scope
self.stream_id = stream_id
self.transmit = transmit
if stream_ended:
self.queue.put_nowait({"type": "http.request"})
def http_event_received(self, event: H3Event) -> None:
if isinstance(event, DataReceived):
self.queue.put_nowait(
{
"type": "http.request",
"body": event.data,
"more_body": not event.stream_ended,
}
)
elif isinstance(event, HeadersReceived) and event.stream_ended:
self.queue.put_nowait(
{"type": "http.request", "body": b"", "more_body": False}
)
async def run_asgi(self, app: AsgiApplication) -> None:
await application(self.scope, self.receive, self.send)
async def receive(self) -> Dict:
return await self.queue.get()
async def send(self, message: Dict) -> None:
if message["type"] == "http.response.start":
self.connection.send_headers(
stream_id=self.stream_id,
headers=[
(b":status", str(message["status"]).encode()),
(b"server", SERVER_NAME.encode()),
(b"date", formatdate(time.time(), usegmt=True).encode()),
]
+ [(k, v) for k, v in message["headers"]],
)
elif message["type"] == "http.response.body":
self.connection.send_data(
stream_id=self.stream_id,
data=message.get("body", b""),
end_stream=not message.get("more_body", False),
)
elif message["type"] == "http.response.push" and isinstance(
self.connection, H3Connection
):
request_headers = [
(b":method", b"GET"),
(b":scheme", b"https"),
(b":authority", self.authority),
(b":path", message["path"].encode()),
] + [(k, v) for k, v in message["headers"]]
# send push promise
try:
push_stream_id = self.connection.send_push_promise(
stream_id=self.stream_id, headers=request_headers
)
except NoAvailablePushIDError:
return
# fake request
cast(HttpServerProtocol, self.protocol).http_event_received(
HeadersReceived(
headers=request_headers, stream_ended=True, stream_id=push_stream_id
)
)
self.transmit()
class WebSocketHandler:
def __init__(
self,
*,
connection: HttpConnection,
scope: Dict,
stream_id: int,
transmit: Callable[[], None],
) -> None:
self.closed = False
self.connection = connection
self.http_event_queue: Deque[DataReceived] = deque()
self.queue: asyncio.Queue[Dict] = asyncio.Queue()
self.scope = scope
self.stream_id = stream_id
self.transmit = transmit
self.websocket: Optional[wsproto.Connection] = None
def http_event_received(self, event: H3Event) -> None:
if isinstance(event, DataReceived) and not self.closed:
if self.websocket is not None:
self.websocket.receive_data(event.data)
for ws_event in self.websocket.events():
self.websocket_event_received(ws_event)
else:
# delay event processing until we get `websocket.accept`
# from the ASGI application
self.http_event_queue.append(event)
def websocket_event_received(self, event: wsproto.events.Event) -> None:
if isinstance(event, wsproto.events.TextMessage):
self.queue.put_nowait({"type": "websocket.receive", "text": event.data})
elif isinstance(event, wsproto.events.Message):
self.queue.put_nowait({"type": "websocket.receive", "bytes": event.data})
elif isinstance(event, wsproto.events.CloseConnection):
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code})
async def run_asgi(self, app: AsgiApplication) -> None:
self.queue.put_nowait({"type": "websocket.connect"})
try:
await application(self.scope, self.receive, self.send)
finally:
if not self.closed:
await self.send({"type": "websocket.close", "code": 1000})
async def receive(self) -> Dict:
return await self.queue.get()
async def send(self, message: Dict) -> None:
data = b""
end_stream = False
if message["type"] == "websocket.accept":
subprotocol = message.get("subprotocol")
self.websocket = wsproto.Connection(wsproto.ConnectionType.SERVER)
headers = [
(b":status", b"200"),
(b"server", SERVER_NAME.encode()),
(b"date", formatdate(time.time(), usegmt=True).encode()),
]
if subprotocol is not None:
headers.append((b"sec-websocket-protocol", subprotocol.encode()))
self.connection.send_headers(stream_id=self.stream_id, headers=headers)
# consume backlog
while self.http_event_queue:
self.http_event_received(self.http_event_queue.popleft())
elif message["type"] == "websocket.close":
if self.websocket is not None:
data = self.websocket.send(
wsproto.events.CloseConnection(code=message["code"])
)
else:
self.connection.send_headers(
stream_id=self.stream_id, headers=[(b":status", b"403")]
)
end_stream = True
elif message["type"] == "websocket.send":
if message.get("text") is not None:
data = self.websocket.send(
wsproto.events.TextMessage(data=message["text"])
)
elif message.get("bytes") is not None:
data = self.websocket.send(
wsproto.events.Message(data=message["bytes"])
)
if data:
self.connection.send_data(
stream_id=self.stream_id, data=data, end_stream=end_stream
)
if end_stream:
self.closed = True
self.transmit()
Handler = Union[HttpRequestHandler, WebSocketHandler]
class HttpServerProtocol(QuicConnectionProtocol):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._handlers: Dict[int, Handler] = {}
self._http: Optional[HttpConnection] = None
def http_event_received(self, event: H3Event) -> None:
if isinstance(event, HeadersReceived) and event.stream_id not in self._handlers:
authority = None
headers = []
http_version = "0.9" if isinstance(self._http, H0Connection) else "3"
raw_path = b""
method = ""
protocol = None
for header, value in event.headers:
if header == b":authority":
authority = value
headers.append((b"host", value))
elif header == b":method":
method = value.decode()
elif header == b":path":
raw_path = value
elif header == b":protocol":
protocol = value.decode()
elif header and not header.startswith(b":"):
headers.append((header, value))
if b"?" in raw_path:
path_bytes, query_string = raw_path.split(b"?", maxsplit=1)
else:
path_bytes, query_string = raw_path, b""
path = path_bytes.decode()
# FIXME: add a public API to retrieve peer address
client_addr = self._http._quic._network_paths[0].addr
client = (client_addr[0], client_addr[1])
handler: Handler
scope: Dict
if method == "CONNECT" and protocol == "websocket":
subprotocols: List[str] = []
for header, value in event.headers:
if header == b"sec-websocket-protocol":
subprotocols = [x.strip() for x in value.decode().split(",")]
scope = {
"client": client,
"headers": headers,
"http_version": http_version,
"method": method,
"path": path,
"query_string": query_string,
"raw_path": raw_path,
"root_path": "",
"scheme": "wss",
"subprotocols": subprotocols,
"type": "websocket",
}
handler = WebSocketHandler(
connection=self._http,
scope=scope,
stream_id=event.stream_id,
transmit=self.transmit,
)
else:
extensions: Dict[str, Dict] = {}
if isinstance(self._http, H3Connection):
extensions["http.response.push"] = {}
scope = {
"client": client,
"extensions": extensions,
"headers": headers,
"http_version": http_version,
"method": method,
"path": path,
"query_string": query_string,
"raw_path": raw_path,
"root_path": "",
"scheme": "https",
"type": "http",
}
handler = HttpRequestHandler(
authority=authority,
connection=self._http,
protocol=self,
scope=scope,
stream_ended=event.stream_ended,
stream_id=event.stream_id,
transmit=self.transmit,
)
self._handlers[event.stream_id] = handler
asyncio.ensure_future(handler.run_asgi(application))
elif (
isinstance(event, (DataReceived, HeadersReceived))
and event.stream_id in self._handlers
):
handler = self._handlers[event.stream_id]
handler.http_event_received(event)
def quic_event_received(self, event: QuicEvent) -> None:
if isinstance(event, ProtocolNegotiated):
if event.alpn_protocol.startswith("h3-"):
self._http = H3Connection(self._quic)
elif event.alpn_protocol.startswith("hq-"):
self._http = H0Connection(self._quic)
elif isinstance(event, DatagramFrameReceived):
if event.data == b"quack":
self._quic.send_datagram_frame(b"quack-ack")
#  pass event to the HTTP layer
if self._http is not None:
for http_event in self._http.handle_event(event):
self.http_event_received(http_event)
class SessionTicketStore:
"""
Simple in-memory store for session tickets.
"""
def __init__(self) -> None:
self.tickets: Dict[bytes, SessionTicket] = {}
def add(self, ticket: SessionTicket) -> None:
self.tickets[ticket.ticket] = ticket
def pop(self, label: bytes) -> Optional[SessionTicket]:
return self.tickets.pop(label, None)
class QuicLoggerCustom(QuicLogger):
"""
Custom QUIC logger which writes one trace per file.
"""
def __init__(self, path: str) -> None:
if not os.path.isdir(path):
raise ValueError("QUIC log output directory '%s' does not exist" % path)
self.path = path
super().__init__()
def end_trace(self, trace: QuicLoggerTrace) -> None:
trace_dict = trace.to_dict()
trace_path = os.path.join(
self.path, trace_dict["common_fields"]["ODCID"] + ".qlog"
)
with open(trace_path, "w") as logger_fp:
json.dump({"qlog_version": "draft-01", "traces": [trace_dict]}, logger_fp)
self._traces.remove(trace)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="QUIC server")
parser.add_argument(
"app",
type=str,
nargs="?",
default="demo:app",
help="the ASGI application as <module>:<attribute>",
)
parser.add_argument(
"-c",
"--certificate",
type=str,
required=True,
help="load the TLS certificate from the specified file",
)
parser.add_argument(
"--host",
type=str,
default="::",
help="listen on the specified address (defaults to ::)",
)
parser.add_argument(
"--port",
type=int,
default=4433,
help="listen on the specified port (defaults to 4433)",
)
parser.add_argument(
"-k",
"--private-key",
type=str,
required=True,
help="load the TLS private key from the specified file",
)
parser.add_argument(
"-l",
"--secrets-log",
type=str,
help="log secrets to a file, for use with Wireshark",
)
parser.add_argument(
"-q", "--quic-log", type=str, help="log QUIC events to a file in QLOG format"
)
parser.add_argument(
"-r",
"--stateless-retry",
action="store_true",
help="send a stateless retry for new connections",
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="increase logging verbosity"
)
args = parser.parse_args()
logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s %(message)s",
level=logging.DEBUG if args.verbose else logging.INFO,
)
# import ASGI application
module_str, attr_str = args.app.split(":", maxsplit=1)
module = importlib.import_module(module_str)
application = getattr(module, attr_str)
# create QUIC logger
if args.quic_log:
quic_logger = QuicLoggerCustom(args.quic_log)
else:
quic_logger = None
# open SSL log file
if args.secrets_log:
secrets_log_file = open(args.secrets_log, "a")
else:
secrets_log_file = None
configuration = QuicConfiguration(
alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"],
is_client=False,
max_datagram_frame_size=65536,
quic_logger=quic_logger,
secrets_log_file=secrets_log_file,
)
# load SSL certificate and key
configuration.load_cert_chain(args.certificate, args.private_key)
ticket_store = SessionTicketStore()
if uvloop is not None:
uvloop.install()
loop = asyncio.get_event_loop()
loop.run_until_complete(
serve(
args.host,
args.port,
configuration=configuration,
create_protocol=HttpServerProtocol,
session_ticket_fetcher=ticket_store.pop,
session_ticket_handler=ticket_store.add,
stateless_retry=args.stateless_retry,
)
)
try:
loop.run_forever()
except KeyboardInterrupt:
pass

@@ -0,0 +1,213 @@
import argparse
import asyncio
import json
import logging
import pickle
import sys
import time
from collections import deque
from typing import Deque, Dict, cast
from urllib.parse import urlparse
from httpx import AsyncClient
from httpx.config import Timeout
from httpx.dispatch.base import AsyncDispatcher
from httpx.models import Request, Response
from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.h3.events import DataReceived, H3Event, HeadersReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import QuicEvent
from aioquic.quic.logger import QuicLogger
logger = logging.getLogger("client")
class H3Dispatcher(QuicConnectionProtocol, AsyncDispatcher):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._http = H3Connection(self._quic)
self._request_events: Dict[int, Deque[H3Event]] = {}
self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {}
async def send(self, request: Request, timeout: Timeout = None) -> Response:
stream_id = self._quic.get_next_available_stream_id()
# prepare request
self._http.send_headers(
stream_id=stream_id,
headers=[
(b":method", request.method.encode()),
(b":scheme", request.url.scheme.encode()),
(b":authority", str(request.url.authority).encode()),
(b":path", request.url.full_path.encode()),
]
+ [
(k.encode(), v.encode())
for (k, v) in request.headers.items()
if k not in ("connection", "host")
],
)
self._http.send_data(stream_id=stream_id, data=request.read(), end_stream=True)
# transmit request
waiter = self._loop.create_future()
self._request_events[stream_id] = deque()
self._request_waiter[stream_id] = waiter
self.transmit()
# process response
events: Deque[H3Event] = await asyncio.shield(waiter)
content = b""
headers = []
status_code = None
for event in events:
if isinstance(event, HeadersReceived):
for header, value in event.headers:
if header == b":status":
status_code = int(value.decode())
elif header[0:1] != b":":
headers.append((header.decode(), value.decode()))
elif isinstance(event, DataReceived):
content += event.data
return Response(
status_code=status_code,
http_version="HTTP/3",
headers=headers,
content=content,
request=request,
)
def http_event_received(self, event: H3Event):
if isinstance(event, (HeadersReceived, DataReceived)):
stream_id = event.stream_id
if stream_id in self._request_events:
self._request_events[event.stream_id].append(event)
if event.stream_ended:
request_waiter = self._request_waiter.pop(stream_id)
request_waiter.set_result(self._request_events.pop(stream_id))
def quic_event_received(self, event: QuicEvent):
#  pass event to the HTTP layer
if self._http is not None:
for http_event in self._http.handle_event(event):
self.http_event_received(http_event)
def save_session_ticket(ticket):
"""
Callback which is invoked by the TLS engine when a new session ticket
is received.
"""
logger.info("New session ticket received")
if args.session_ticket:
with open(args.session_ticket, "wb") as fp:
pickle.dump(ticket, fp)
async def run(configuration: QuicConfiguration, url: str, data: str) -> None:
# parse URL
parsed = urlparse(url)
assert parsed.scheme == "https", "Only https:// URLs are supported."
if ":" in parsed.netloc:
host, port_str = parsed.netloc.split(":")
port = int(port_str)
else:
host = parsed.netloc
port = 443
async with connect(
host,
port,
configuration=configuration,
create_protocol=H3Dispatcher,
session_ticket_handler=save_session_ticket,
) as dispatch:
client = AsyncClient(dispatch=cast(AsyncDispatcher, dispatch))
# perform request
start = time.time()
if data is not None:
response = await client.post(
url,
data=data.encode(),
headers={"content-type": "application/x-www-form-urlencoded"},
)
else:
response = await client.get(url)
elapsed = time.time() - start
# print speed
octets = len(response.content)
logger.info(
"Received %d bytes in %.1f s (%.3f Mbps)"
% (octets, elapsed, octets * 8 / elapsed / 1000000)
)
# print response
for header, value in response.headers.items():
sys.stderr.write(header + ": " + value + "\r\n")
sys.stderr.write("\r\n")
sys.stdout.buffer.write(response.content)
sys.stdout.buffer.flush()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="HTTP/3 client")
parser.add_argument("url", type=str, help="the URL to query (must be HTTPS)")
parser.add_argument(
"-d", "--data", type=str, help="send the specified data in a POST request"
)
parser.add_argument(
"-q", "--quic-log", type=str, help="log QUIC events to a file in QLOG format"
)
parser.add_argument(
"-l",
"--secrets-log",
type=str,
help="log secrets to a file, for use with Wireshark",
)
parser.add_argument(
"-s",
"--session-ticket",
type=str,
help="read and write session ticket from the specified file",
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="increase logging verbosity"
)
args = parser.parse_args()
logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s %(message)s",
level=logging.DEBUG if args.verbose else logging.INFO,
)
# prepare configuration
configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN)
if args.quic_log:
configuration.quic_logger = QuicLogger()
if args.secrets_log:
configuration.secrets_log_file = open(args.secrets_log, "a")
if args.session_ticket:
try:
with open(args.session_ticket, "rb") as fp:
configuration.session_ticket = pickle.load(fp)
except FileNotFoundError:
pass
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(
run(configuration=configuration, url=args.url, data=args.data)
)
finally:
if configuration.quic_logger is not None:
with open(args.quic_log, "w") as logger_fp:
json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4)

@@ -0,0 +1,532 @@
#
# !!! WARNING !!!
#
# This example uses some private APIs.
#
import argparse
import asyncio
import json
import logging
import ssl
import time
from dataclasses import dataclass, field
from enum import Flag
from typing import Optional, cast
import requests
import urllib3
from http3_client import HttpClient
from aioquic.asyncio import connect
from aioquic.h0.connection import H0_ALPN
from aioquic.h3.connection import H3_ALPN, H3Connection
from aioquic.h3.events import DataReceived, HeadersReceived, PushPromiseReceived
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.logger import QuicLogger
class Result(Flag):
V = 0x000001
H = 0x000002
D = 0x000004
C = 0x000008
R = 0x000010
Z = 0x000020
S = 0x000040
Q = 0x000080
M = 0x000100
B = 0x000200
A = 0x000400
U = 0x000800
P = 0x001000
E = 0x002000
L = 0x004000
T = 0x008000
three = 0x010000
d = 0x020000
p = 0x040000
def __str__(self):
flags = sorted(
map(
lambda x: getattr(Result, x),
filter(lambda x: not x.startswith("_"), dir(Result)),
),
key=lambda x: x.value,
)
result_str = ""
for flag in flags:
if self & flag:
result_str += flag.name
else:
result_str += "-"
return result_str
@dataclass
class Server:
name: str
host: str
port: int = 4433
http3: bool = True
retry_port: Optional[int] = 4434
path: str = "/"
push_path: Optional[str] = None
result: Result = field(default_factory=lambda: Result(0))
session_resumption_port: Optional[int] = None
structured_logging: bool = False
throughput_file_suffix: str = ""
verify_mode: Optional[int] = None
SERVERS = [
Server("akamaiquic", "ietf.akaquic.com", port=443, verify_mode=ssl.CERT_NONE),
Server(
"aioquic", "quic.aiortc.org", port=443, push_path="/", structured_logging=True
),
Server("ats", "quic.ogre.com"),
Server("f5", "f5quic.com", retry_port=4433),
Server("haskell", "mew.org", retry_port=4433),
Server("gquic", "quic.rocks", retry_port=None),
Server("lsquic", "http3-test.litespeedtech.com", push_path="/200?push=/100"),
Server(
"msquic",
"quic.westus.cloudapp.azure.com",
port=4433,
session_resumption_port=4433,
structured_logging=True,
throughput_file_suffix=".txt",
verify_mode=ssl.CERT_NONE,
),
Server(
"mvfst", "fb.mvfst.net", port=443, push_path="/push", structured_logging=True
),
Server("ngtcp2", "nghttp2.org", push_path="/?push=/100"),
Server("ngx_quic", "cloudflare-quic.com", port=443, retry_port=443),
Server("pandora", "pandora.cm.in.tum.de", verify_mode=ssl.CERT_NONE),
Server("picoquic", "test.privateoctopus.com", structured_logging=True),
Server("quant", "quant.eggert.org", http3=False),
Server("quic-go", "quic.seemann.io", port=443, retry_port=443),
Server("quiche", "quic.tech", port=8443, retry_port=8444),
Server("quicly", "quic.examp1e.net"),
Server("quinn", "ralith.com"),
]
async def test_version_negotiation(server: Server, configuration: QuicConfiguration):
# force version negotiation
configuration.supported_versions.insert(0, 0x1A2A3A4A)
async with connect(
server.host, server.port, configuration=configuration
) as protocol:
await protocol.ping()
# check log
for stamp, category, event, data in configuration.quic_logger.to_dict()[
"traces"
][0]["events"]:
if (
category == "transport"
and event == "packet_received"
and data["packet_type"] == "version_negotiation"
):
server.result |= Result.V
async def test_handshake_and_close(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.port, configuration=configuration
) as protocol:
await protocol.ping()
server.result |= Result.H
server.result |= Result.C
async def test_stateless_retry(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.retry_port, configuration=configuration
) as protocol:
await protocol.ping()
# check log
for stamp, category, event, data in configuration.quic_logger.to_dict()[
"traces"
][0]["events"]:
if (
category == "transport"
and event == "packet_received"
and data["packet_type"] == "retry"
):
server.result |= Result.S
async def test_quantum_readiness(server: Server, configuration: QuicConfiguration):
configuration.quantum_readiness_test = True
async with connect(
server.host, server.port, configuration=configuration
) as protocol:
await protocol.ping()
server.result |= Result.Q
async def test_http_0(server: Server, configuration: QuicConfiguration):
if server.path is None:
return
configuration.alpn_protocols = H0_ALPN
async with connect(
server.host,
server.port,
configuration=configuration,
create_protocol=HttpClient,
) as protocol:
protocol = cast(HttpClient, protocol)
# perform HTTP request
events = await protocol.get(
"https://{}:{}{}".format(server.host, server.port, server.path)
)
if events and isinstance(events[0], HeadersReceived):
server.result |= Result.D
async def test_http_3(server: Server, configuration: QuicConfiguration):
if server.path is None:
return
configuration.alpn_protocols = H3_ALPN
async with connect(
server.host,
server.port,
configuration=configuration,
create_protocol=HttpClient,
) as protocol:
protocol = cast(HttpClient, protocol)
# perform HTTP request
events = await protocol.get(
"https://{}:{}{}".format(server.host, server.port, server.path)
)
if events and isinstance(events[0], HeadersReceived):
server.result |= Result.D
server.result |= Result.three
# perform more HTTP requests to use QPACK dynamic tables
for i in range(2):
events = await protocol.get(
"https://{}:{}{}".format(server.host, server.port, server.path)
)
if events and isinstance(events[0], HeadersReceived):
http = cast(H3Connection, protocol._http)
protocol._quic._logger.info(
"QPACK decoder bytes RX %d TX %d",
http._decoder_bytes_received,
http._decoder_bytes_sent,
)
protocol._quic._logger.info(
"QPACK encoder bytes RX %d TX %d",
http._encoder_bytes_received,
http._encoder_bytes_sent,
)
if (
http._decoder_bytes_received
and http._decoder_bytes_sent
and http._encoder_bytes_received
and http._encoder_bytes_sent
):
server.result |= Result.d
# check push support
if server.push_path is not None:
protocol.pushes.clear()
await protocol.get(
"https://{}:{}{}".format(server.host, server.port, server.push_path)
)
await asyncio.sleep(0.5)
for push_id, events in protocol.pushes.items():
if (
len(events) >= 3
and isinstance(events[0], PushPromiseReceived)
and isinstance(events[1], HeadersReceived)
and isinstance(events[2], DataReceived)
):
protocol._quic._logger.info(
"Push promise %d for %s received (status %s)",
push_id,
dict(events[0].headers)[b":path"].decode("ascii"),
int(dict(events[1].headers)[b":status"]),
)
server.result |= Result.p
async def test_session_resumption(server: Server, configuration: QuicConfiguration):
port = server.session_resumption_port or server.port
saved_ticket = None
def session_ticket_handler(ticket):
nonlocal saved_ticket
saved_ticket = ticket
# connect a first time, receive a ticket
async with connect(
server.host,
port,
configuration=configuration,
session_ticket_handler=session_ticket_handler,
) as protocol:
await protocol.ping()
# some servers don't send the ticket immediately
await asyncio.sleep(1)
# connect a second time, with the ticket
if saved_ticket is not None:
configuration.session_ticket = saved_ticket
async with connect(server.host, port, configuration=configuration) as protocol:
await protocol.ping()
# check session was resumed
if protocol._quic.tls.session_resumed:
server.result |= Result.R
# check early data was accepted
if protocol._quic.tls.early_data_accepted:
server.result |= Result.Z
async def test_key_update(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.port, configuration=configuration
) as protocol:
# cause some traffic
await protocol.ping()
# request key update
protocol.request_key_update()
# cause more traffic
await protocol.ping()
server.result |= Result.U
async def test_migration(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.port, configuration=configuration
) as protocol:
# cause some traffic
await protocol.ping()
# change connection ID and replace transport
protocol.change_connection_id()
protocol._transport.close()
await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0))
# cause more traffic
await protocol.ping()
# check log
dcids = set()
for stamp, category, event, data in configuration.quic_logger.to_dict()[
"traces"
][0]["events"]:
if (
category == "transport"
and event == "packet_received"
and data["packet_type"] == "1RTT"
):
dcids.add(data["header"]["dcid"])
if len(dcids) == 2:
server.result |= Result.M
async def test_rebinding(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.port, configuration=configuration
) as protocol:
# cause some traffic
await protocol.ping()
# replace transport
protocol._transport.close()
await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0))
# cause more traffic
await protocol.ping()
server.result |= Result.B
async def test_spin_bit(server: Server, configuration: QuicConfiguration):
async with connect(
server.host, server.port, configuration=configuration
) as protocol:
for i in range(5):
await protocol.ping()
# check log
spin_bits = set()
for stamp, category, event, data in configuration.quic_logger.to_dict()[
"traces"
][0]["events"]:
if category == "connectivity" and event == "spin_bit_updated":
spin_bits.add(data["state"])
if len(spin_bits) == 2:
server.result |= Result.P
async def test_throughput(server: Server, configuration: QuicConfiguration):
failures = 0
for size in [5000000, 10000000]:
print("Testing %d bytes download" % size)
path = "/%d%s" % (size, server.throughput_file_suffix)
# perform HTTP request over TCP
start = time.time()
response = requests.get("https://" + server.host + path, verify=False)
tcp_octets = len(response.content)
tcp_elapsed = time.time() - start
assert tcp_octets == size, "HTTP/TCP response size mismatch"
# perform HTTP request over QUIC
if server.http3:
configuration.alpn_protocols = H3_ALPN
else:
configuration.alpn_protocols = H0_ALPN
start = time.time()
async with connect(
server.host,
server.port,
configuration=configuration,
create_protocol=HttpClient,
) as protocol:
protocol = cast(HttpClient, protocol)
http_events = await protocol.get(
"https://{}:{}{}".format(server.host, server.port, path)
)
quic_elapsed = time.time() - start
quic_octets = 0
for http_event in http_events:
if isinstance(http_event, DataReceived):
quic_octets += len(http_event.data)
assert quic_octets == size, "HTTP/QUIC response size mismatch"
print(" - HTTP/TCP completed in %.3f s" % tcp_elapsed)
print(" - HTTP/QUIC completed in %.3f s" % quic_elapsed)
if quic_elapsed > 1.1 * tcp_elapsed:
failures += 1
print(" => FAIL")
else:
print(" => PASS")
if failures == 0:
server.result |= Result.T
def print_result(server: Server) -> None:
result = str(server.result).replace("three", "3")
result = result[0:8] + " " + result[8:16] + " " + result[16:]
print("%s%s%s" % (server.name, " " * (20 - len(server.name)), result))
async def run(servers, tests, quic_log=False, secrets_log_file=None) -> None:
for server in servers:
if server.structured_logging:
server.result |= Result.L
for test_name, test_func in tests:
print("\n=== %s %s ===\n" % (server.name, test_name))
configuration = QuicConfiguration(
alpn_protocols=H3_ALPN + H0_ALPN,
is_client=True,
quic_logger=QuicLogger(),
secrets_log_file=secrets_log_file,
verify_mode=server.verify_mode,
)
if test_name == "test_throughput":
timeout = 60
else:
timeout = 10
try:
await asyncio.wait_for(
test_func(server, configuration), timeout=timeout
)
except Exception as exc:
print(exc)
if quic_log:
with open("%s-%s.qlog" % (server.name, test_name), "w") as logger_fp:
json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4)
print("")
print_result(server)
# print summary
if len(servers) > 1:
print("SUMMARY")
for server in servers:
print_result(server)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="QUIC interop client")
parser.add_argument(
"-q",
"--quic-log",
action="store_true",
help="log QUIC events to a file in QLOG format",
)
parser.add_argument(
"--server", type=str, help="only run against the specified server."
)
parser.add_argument("--test", type=str, help="only run the specifed test.")
parser.add_argument(
"-l",
"--secrets-log",
type=str,
help="log secrets to a file, for use with Wireshark",
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="increase logging verbosity"
)
args = parser.parse_args()
logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s %(message)s",
level=logging.DEBUG if args.verbose else logging.INFO,
)
# open SSL log file
if args.secrets_log:
secrets_log_file = open(args.secrets_log, "a")
else:
secrets_log_file = None
# determine what to run
servers = SERVERS
tests = list(filter(lambda x: x[0].startswith("test_"), globals().items()))
if args.server:
servers = list(filter(lambda x: x.name == args.server, servers))
if args.test:
tests = list(filter(lambda x: x[0] == args.test, tests))
# disable requests SSL warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
loop = asyncio.get_event_loop()
loop.run_until_complete(
run(
servers=servers,
tests=tests,
quic_log=args.quic_log,
secrets_log_file=secrets_log_file,
)
)
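The harness above discovers its checks by name: every module-level coroutine starting with test_ is collected through globals(), run against each server with a fresh QuicConfiguration, and reports success by OR-ing a Result flag onto server.result. The same entry points can be driven programmatically; a minimal sketch, assuming the script is importable as interop with its example dependency (http3_client) on the path:
import asyncio
from interop import SERVERS, run, test_handshake_and_close

# Run only the handshake/close check against a single server from the list.
servers = [s for s in SERVERS if s.name == "aioquic"]
tests = [("test_handshake_and_close", test_handshake_and_close)]
asyncio.get_event_loop().run_until_complete(
    run(servers=servers, tests=tests, quic_log=False, secrets_log_file=None)
)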


@ -0,0 +1,100 @@
import argparse
import asyncio
import json
import logging
import ssl
from typing import Optional, cast
from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.events import DatagramFrameReceived, QuicEvent
from aioquic.quic.logger import QuicLogger
logger = logging.getLogger("client")
class SiduckClient(QuicConnectionProtocol):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._ack_waiter: Optional[asyncio.Future[None]] = None
async def quack(self) -> None:
assert self._ack_waiter is None, "Only one quack at a time."
self._quic.send_datagram_frame(b"quack")
waiter = self._loop.create_future()
self._ack_waiter = waiter
self.transmit()
return await asyncio.shield(waiter)
def quic_event_received(self, event: QuicEvent) -> None:
if self._ack_waiter is not None:
if isinstance(event, DatagramFrameReceived) and event.data == b"quack-ack":
waiter = self._ack_waiter
self._ack_waiter = None
waiter.set_result(None)
async def run(configuration: QuicConfiguration, host: str, port: int) -> None:
async with connect(
host, port, configuration=configuration, create_protocol=SiduckClient
) as client:
client = cast(SiduckClient, client)
logger.info("sending quack")
await client.quack()
logger.info("received quack-ack")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SiDUCK client")
parser.add_argument(
"host", type=str, help="The remote peer's host name or IP address"
)
parser.add_argument("port", type=int, help="The remote peer's port number")
parser.add_argument(
"-k",
"--insecure",
action="store_true",
help="do not validate server certificate",
)
parser.add_argument(
"-q", "--quic-log", type=str, help="log QUIC events to a file in QLOG format"
)
parser.add_argument(
"-l",
"--secrets-log",
type=str,
help="log secrets to a file, for use with Wireshark",
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="increase logging verbosity"
)
args = parser.parse_args()
logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s %(message)s",
level=logging.DEBUG if args.verbose else logging.INFO,
)
configuration = QuicConfiguration(
alpn_protocols=["siduck"], is_client=True, max_datagram_frame_size=65536
)
if args.insecure:
configuration.verify_mode = ssl.CERT_NONE
if args.quic_log:
configuration.quic_logger = QuicLogger()
if args.secrets_log:
configuration.secrets_log_file = open(args.secrets_log, "a")
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(
run(configuration=configuration, host=args.host, port=args.port)
)
finally:
if configuration.quic_logger is not None:
with open(args.quic_log, "w") as logger_fp:
json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4)


@ -0,0 +1,31 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8"/>
<title>aioquic</title>
<link rel="stylesheet" href="/style.css"/>
</head>
<body>
<h1>Welcome to aioquic</h1>
<p>
This is a test page for <a href="https://github.com/aiortc/aioquic/">aioquic</a>,
a QUIC and HTTP/3 implementation written in Python.
</p>
{% if request.scope["http_version"] == "3" %}
<p>
Congratulations, you loaded this page using HTTP/3!
</p>
{% endif %}
<h2>Available endpoints</h2>
<ul>
<li><strong>GET /</strong> returns the homepage</li>
<li><strong>GET /NNNNN</strong> returns NNNNN bytes of plain text</li>
<li><strong>POST /echo</strong> returns the request data</li>
<li>
<strong>CONNECT /ws</strong> runs a WebSocket echo service.
You must set the <em>:protocol</em> pseudo-header to <em>"websocket"</em>.
</li>
<li>There is also an <a href="/httpbin/">httpbin instance</a>.</li>
</ul>
</body>
</html>


@ -0,0 +1,28 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8"/>
<title>aioquic - logs</title>
<link rel="stylesheet" href="/style.css"/>
</head>
<body>
<h1>QLOG files</h1>
<table class="logs">
<tr>
<th>name</th>
<th>date (UTC)</th>
<th>size</th>
</tr>
{% for log in logs %}
<tr>
<td class="name">
<a href="{{ log.file_url }}">{{ log.name }}</a>
<a href="{{ log.qvis_url }}">[qvis]</a>
</td>
<td class="date">{{ log.date }}</td>
<td class="size">{{ log.size }}</td>
</tr>
{% endfor %}
</table>
</body>
</html>


@ -0,0 +1,3 @@
cryptography
sphinx_autodoc_typehints
sphinxcontrib-asyncio


@ -0,0 +1,29 @@
#!/bin/sh
set -e
destdir=$1
cachedir=$1.$PYTHON_ARCH
for d in openssl $destdir; do
if [ -e $d ]; then
rm -rf $d
fi
done
if [ ! -e $cachedir ]; then
# build openssl
mkdir openssl
curl -L https://www.openssl.org/source/openssl-1.1.1f.tar.gz | tar xz -C openssl --strip-components 1
cd openssl
./config no-comp no-shared no-tests
make
mkdir $cachedir
mkdir $cachedir/lib
cp -R include $cachedir
cp libcrypto.a libssl.a $cachedir/lib
fi
cp -R $cachedir $destdir


@ -0,0 +1,39 @@
set destdir=%1
set cachedir=%1.%PYTHON_ARCH%
for %%d in (openssl %destdir%) do (
if exist %%d (
rmdir /s /q %%d
)
)
if %PYTHON_ARCH% == 64 (
set OPENSSL_CONFIG=VC-WIN64A
set VC_ARCH=x64
) else (
set OPENSSL_CONFIG=VC-WIN32
set VC_ARCH=x86
)
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" %VC_ARCH%
SET PATH=%PATH%;C:\Program Files\NASM
if not exist %cachedir% (
mkdir openssl
curl -L https://www.openssl.org/source/openssl-1.1.1f.tar.gz -o openssl.tar.gz
tar xzf openssl.tar.gz -C openssl --strip-components 1
cd openssl
perl Configure no-comp no-shared no-tests %OPENSSL_CONFIG%
nmake
mkdir %cachedir%
mkdir %cachedir%\include
mkdir %cachedir%\lib
xcopy include %cachedir%\include\ /E
copy libcrypto.lib %cachedir%\lib\
copy libssl.lib %cachedir%\lib\
)
mkdir %destdir%
xcopy %cachedir% %destdir% /E

21
testing/web-platform/tests/tools/third_party/aioquic/setup.cfg vendored Normal file

@ -0,0 +1,21 @@
[coverage:run]
source = aioquic
[flake8]
ignore=E203,W503
max-line-length=150
[isort]
default_section = THIRDPARTY
include_trailing_comma = True
known_first_party = aioquic
line_length = 88
multi_line_output = 3
[mypy]
disallow_untyped_calls = True
disallow_untyped_decorators = True
ignore_missing_imports = True
strict_optional = False
warn_redundant_casts = True
warn_unused_ignores = True

69
testing/web-platform/tests/tools/third_party/aioquic/setup.py vendored Normal file

@ -0,0 +1,69 @@
import os.path
import sys
import setuptools
root_dir = os.path.abspath(os.path.dirname(__file__))
about = {}
about_file = os.path.join(root_dir, "src", "aioquic", "about.py")
with open(about_file, encoding="utf-8") as fp:
exec(fp.read(), about)
readme_file = os.path.join(root_dir, "README.rst")
with open(readme_file, encoding="utf-8") as f:
long_description = f.read()
if sys.platform == "win32":
extra_compile_args = []
libraries = ["libcrypto", "advapi32", "crypt32", "gdi32", "user32", "ws2_32"]
else:
extra_compile_args = ["-std=c99"]
libraries = ["crypto"]
setuptools.setup(
name=about["__title__"],
version=about["__version__"],
description=about["__summary__"],
long_description=long_description,
url=about["__uri__"],
author=about["__author__"],
author_email=about["__email__"],
license=about["__license__"],
include_package_data=True,
classifiers=[
"Development Status :: 4 - Beta",
"Environment :: Web Environment",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Internet :: WWW/HTTP",
],
ext_modules=[
setuptools.Extension(
"aioquic._buffer",
extra_compile_args=extra_compile_args,
sources=["src/aioquic/_buffer.c"],
),
setuptools.Extension(
"aioquic._crypto",
extra_compile_args=extra_compile_args,
libraries=libraries,
sources=["src/aioquic/_crypto.c"],
),
],
package_dir={"": "src"},
package_data={"aioquic": ["py.typed", "_buffer.pyi", "_crypto.pyi"]},
packages=["aioquic", "aioquic.asyncio", "aioquic.h0", "aioquic.h3", "aioquic.quic"],
install_requires=[
"certifi",
"cryptography >= 2.5",
'dataclasses; python_version < "3.7"',
"pylsqpack >= 0.3.3, < 0.4.0",
],
)


@ -0,0 +1,3 @@
# flake8: noqa
from .about import __version__


@ -0,0 +1,440 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdint.h>
#define MODULE_NAME "aioquic._buffer"
static PyObject *BufferReadError;
static PyObject *BufferWriteError;
typedef struct {
PyObject_HEAD
uint8_t *base;
uint8_t *end;
uint8_t *pos;
} BufferObject;
#define CHECK_READ_BOUNDS(self, len) \
if (len < 0 || self->pos + len > self->end) { \
PyErr_SetString(BufferReadError, "Read out of bounds"); \
return NULL; \
}
#define CHECK_WRITE_BOUNDS(self, len) \
if (self->pos + len > self->end) { \
PyErr_SetString(BufferWriteError, "Write out of bounds"); \
return NULL; \
}
static int
Buffer_init(BufferObject *self, PyObject *args, PyObject *kwargs)
{
const char *kwlist[] = {"capacity", "data", NULL};
int capacity = 0;
const unsigned char *data = NULL;
Py_ssize_t data_len = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|iy#", (char**)kwlist, &capacity, &data, &data_len))
return -1;
if (data != NULL) {
self->base = malloc(data_len);
self->end = self->base + data_len;
memcpy(self->base, data, data_len);
} else {
self->base = malloc(capacity);
self->end = self->base + capacity;
}
self->pos = self->base;
return 0;
}
static void
Buffer_dealloc(BufferObject *self)
{
free(self->base);
}
static PyObject *
Buffer_data_slice(BufferObject *self, PyObject *args)
{
int start, stop;
if (!PyArg_ParseTuple(args, "ii", &start, &stop))
return NULL;
if (start < 0 || self->base + start > self->end ||
stop < 0 || self->base + stop > self->end ||
stop < start) {
PyErr_SetString(BufferReadError, "Read out of bounds");
return NULL;
}
return PyBytes_FromStringAndSize((const char*)(self->base + start), (stop - start));
}
static PyObject *
Buffer_eof(BufferObject *self, PyObject *args)
{
if (self->pos == self->end)
Py_RETURN_TRUE;
Py_RETURN_FALSE;
}
static PyObject *
Buffer_pull_bytes(BufferObject *self, PyObject *args)
{
int len;
if (!PyArg_ParseTuple(args, "i", &len))
return NULL;
CHECK_READ_BOUNDS(self, len);
PyObject *o = PyBytes_FromStringAndSize((const char*)self->pos, len);
self->pos += len;
return o;
}
static PyObject *
Buffer_pull_uint8(BufferObject *self, PyObject *args)
{
CHECK_READ_BOUNDS(self, 1)
return PyLong_FromUnsignedLong(
(uint8_t)(*(self->pos++))
);
}
static PyObject *
Buffer_pull_uint16(BufferObject *self, PyObject *args)
{
CHECK_READ_BOUNDS(self, 2)
uint16_t value = (uint16_t)(*(self->pos)) << 8 |
(uint16_t)(*(self->pos + 1));
self->pos += 2;
return PyLong_FromUnsignedLong(value);
}
static PyObject *
Buffer_pull_uint32(BufferObject *self, PyObject *args)
{
CHECK_READ_BOUNDS(self, 4)
uint32_t value = (uint32_t)(*(self->pos)) << 24 |
(uint32_t)(*(self->pos + 1)) << 16 |
(uint32_t)(*(self->pos + 2)) << 8 |
(uint32_t)(*(self->pos + 3));
self->pos += 4;
return PyLong_FromUnsignedLong(value);
}
static PyObject *
Buffer_pull_uint64(BufferObject *self, PyObject *args)
{
CHECK_READ_BOUNDS(self, 8)
uint64_t value = (uint64_t)(*(self->pos)) << 56 |
(uint64_t)(*(self->pos + 1)) << 48 |
(uint64_t)(*(self->pos + 2)) << 40 |
(uint64_t)(*(self->pos + 3)) << 32 |
(uint64_t)(*(self->pos + 4)) << 24 |
(uint64_t)(*(self->pos + 5)) << 16 |
(uint64_t)(*(self->pos + 6)) << 8 |
(uint64_t)(*(self->pos + 7));
self->pos += 8;
return PyLong_FromUnsignedLongLong(value);
}
static PyObject *
Buffer_pull_uint_var(BufferObject *self, PyObject *args)
{
uint64_t value;
CHECK_READ_BOUNDS(self, 1)
switch (*(self->pos) >> 6) {
case 0:
value = *(self->pos++) & 0x3F;
break;
case 1:
CHECK_READ_BOUNDS(self, 2)
value = (uint16_t)(*(self->pos) & 0x3F) << 8 |
(uint16_t)(*(self->pos + 1));
self->pos += 2;
break;
case 2:
CHECK_READ_BOUNDS(self, 4)
value = (uint32_t)(*(self->pos) & 0x3F) << 24 |
(uint32_t)(*(self->pos + 1)) << 16 |
(uint32_t)(*(self->pos + 2)) << 8 |
(uint32_t)(*(self->pos + 3));
self->pos += 4;
break;
default:
CHECK_READ_BOUNDS(self, 8)
value = (uint64_t)(*(self->pos) & 0x3F) << 56 |
(uint64_t)(*(self->pos + 1)) << 48 |
(uint64_t)(*(self->pos + 2)) << 40 |
(uint64_t)(*(self->pos + 3)) << 32 |
(uint64_t)(*(self->pos + 4)) << 24 |
(uint64_t)(*(self->pos + 5)) << 16 |
(uint64_t)(*(self->pos + 6)) << 8 |
(uint64_t)(*(self->pos + 7));
self->pos += 8;
break;
}
return PyLong_FromUnsignedLongLong(value);
}
static PyObject *
Buffer_push_bytes(BufferObject *self, PyObject *args)
{
const unsigned char *data;
Py_ssize_t data_len;
if (!PyArg_ParseTuple(args, "y#", &data, &data_len))
return NULL;
CHECK_WRITE_BOUNDS(self, data_len)
memcpy(self->pos, data, data_len);
self->pos += data_len;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint8(BufferObject *self, PyObject *args)
{
uint8_t value;
if (!PyArg_ParseTuple(args, "B", &value))
return NULL;
CHECK_WRITE_BOUNDS(self, 1)
*(self->pos++) = value;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint16(BufferObject *self, PyObject *args)
{
uint16_t value;
if (!PyArg_ParseTuple(args, "H", &value))
return NULL;
CHECK_WRITE_BOUNDS(self, 2)
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint32(BufferObject *self, PyObject *args)
{
uint32_t value;
if (!PyArg_ParseTuple(args, "I", &value))
return NULL;
CHECK_WRITE_BOUNDS(self, 4)
*(self->pos++) = (value >> 24);
*(self->pos++) = (value >> 16);
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint64(BufferObject *self, PyObject *args)
{
uint64_t value;
if (!PyArg_ParseTuple(args, "K", &value))
return NULL;
CHECK_WRITE_BOUNDS(self, 8)
*(self->pos++) = (value >> 56);
*(self->pos++) = (value >> 48);
*(self->pos++) = (value >> 40);
*(self->pos++) = (value >> 32);
*(self->pos++) = (value >> 24);
*(self->pos++) = (value >> 16);
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
}
static PyObject *
Buffer_push_uint_var(BufferObject *self, PyObject *args)
{
uint64_t value;
if (!PyArg_ParseTuple(args, "K", &value))
return NULL;
if (value <= 0x3F) {
CHECK_WRITE_BOUNDS(self, 1)
*(self->pos++) = value;
Py_RETURN_NONE;
} else if (value <= 0x3FFF) {
CHECK_WRITE_BOUNDS(self, 2)
*(self->pos++) = (value >> 8) | 0x40;
*(self->pos++) = value;
Py_RETURN_NONE;
} else if (value <= 0x3FFFFFFF) {
CHECK_WRITE_BOUNDS(self, 4)
*(self->pos++) = (value >> 24) | 0x80;
*(self->pos++) = (value >> 16);
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
} else if (value <= 0x3FFFFFFFFFFFFFFF) {
CHECK_WRITE_BOUNDS(self, 8)
*(self->pos++) = (value >> 56) | 0xC0;
*(self->pos++) = (value >> 48);
*(self->pos++) = (value >> 40);
*(self->pos++) = (value >> 32);
*(self->pos++) = (value >> 24);
*(self->pos++) = (value >> 16);
*(self->pos++) = (value >> 8);
*(self->pos++) = value;
Py_RETURN_NONE;
} else {
PyErr_SetString(PyExc_ValueError, "Integer is too big for a variable-length integer");
return NULL;
}
}
static PyObject *
Buffer_seek(BufferObject *self, PyObject *args)
{
int pos;
if (!PyArg_ParseTuple(args, "i", &pos))
return NULL;
if (pos < 0 || self->base + pos > self->end) {
PyErr_SetString(BufferReadError, "Seek out of bounds");
return NULL;
}
self->pos = self->base + pos;
Py_RETURN_NONE;
}
static PyObject *
Buffer_tell(BufferObject *self, PyObject *args)
{
return PyLong_FromSsize_t(self->pos - self->base);
}
static PyMethodDef Buffer_methods[] = {
{"data_slice", (PyCFunction)Buffer_data_slice, METH_VARARGS, ""},
{"eof", (PyCFunction)Buffer_eof, METH_VARARGS, ""},
{"pull_bytes", (PyCFunction)Buffer_pull_bytes, METH_VARARGS, "Pull bytes."},
{"pull_uint8", (PyCFunction)Buffer_pull_uint8, METH_VARARGS, "Pull an 8-bit unsigned integer."},
{"pull_uint16", (PyCFunction)Buffer_pull_uint16, METH_VARARGS, "Pull a 16-bit unsigned integer."},
{"pull_uint32", (PyCFunction)Buffer_pull_uint32, METH_VARARGS, "Pull a 32-bit unsigned integer."},
{"pull_uint64", (PyCFunction)Buffer_pull_uint64, METH_VARARGS, "Pull a 64-bit unsigned integer."},
{"pull_uint_var", (PyCFunction)Buffer_pull_uint_var, METH_VARARGS, "Pull a QUIC variable-length unsigned integer."},
{"push_bytes", (PyCFunction)Buffer_push_bytes, METH_VARARGS, "Push bytes."},
{"push_uint8", (PyCFunction)Buffer_push_uint8, METH_VARARGS, "Push an 8-bit unsigned integer."},
{"push_uint16", (PyCFunction)Buffer_push_uint16, METH_VARARGS, "Push a 16-bit unsigned integer."},
{"push_uint32", (PyCFunction)Buffer_push_uint32, METH_VARARGS, "Push a 32-bit unsigned integer."},
{"push_uint64", (PyCFunction)Buffer_push_uint64, METH_VARARGS, "Push a 64-bit unsigned integer."},
{"push_uint_var", (PyCFunction)Buffer_push_uint_var, METH_VARARGS, "Push a QUIC variable-length unsigned integer."},
{"seek", (PyCFunction)Buffer_seek, METH_VARARGS, ""},
{"tell", (PyCFunction)Buffer_tell, METH_VARARGS, ""},
{NULL}
};
static PyObject*
Buffer_capacity_getter(BufferObject* self, void *closure) {
return PyLong_FromSsize_t(self->end - self->base);
}
static PyObject*
Buffer_data_getter(BufferObject* self, void *closure) {
return PyBytes_FromStringAndSize((const char*)self->base, self->pos - self->base);
}
static PyGetSetDef Buffer_getset[] = {
{"capacity", (getter) Buffer_capacity_getter, NULL, "", NULL },
{"data", (getter) Buffer_data_getter, NULL, "", NULL },
{NULL}
};
static PyTypeObject BufferType = {
PyVarObject_HEAD_INIT(NULL, 0)
MODULE_NAME ".Buffer", /* tp_name */
sizeof(BufferObject), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)Buffer_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"Buffer objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
Buffer_methods, /* tp_methods */
0, /* tp_members */
Buffer_getset, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)Buffer_init, /* tp_init */
0, /* tp_alloc */
};
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
MODULE_NAME, /* m_name */
"A faster buffer.", /* m_doc */
-1, /* m_size */
NULL, /* m_methods */
NULL, /* m_reload */
NULL, /* m_traverse */
NULL, /* m_clear */
NULL, /* m_free */
};
PyMODINIT_FUNC
PyInit__buffer(void)
{
PyObject* m;
m = PyModule_Create(&moduledef);
if (m == NULL)
return NULL;
BufferReadError = PyErr_NewException(MODULE_NAME ".BufferReadError", PyExc_ValueError, NULL);
Py_INCREF(BufferReadError);
PyModule_AddObject(m, "BufferReadError", BufferReadError);
BufferWriteError = PyErr_NewException(MODULE_NAME ".BufferWriteError", PyExc_ValueError, NULL);
Py_INCREF(BufferWriteError);
PyModule_AddObject(m, "BufferWriteError", BufferWriteError);
BufferType.tp_new = PyType_GenericNew;
if (PyType_Ready(&BufferType) < 0)
return NULL;
Py_INCREF(&BufferType);
PyModule_AddObject(m, "Buffer", (PyObject *)&BufferType);
return m;
}


@ -0,0 +1,27 @@
from typing import Optional
class BufferReadError(ValueError): ...
class BufferWriteError(ValueError): ...
class Buffer:
def __init__(self, capacity: Optional[int] = 0, data: Optional[bytes] = None): ...
@property
def capacity(self) -> int: ...
@property
def data(self) -> bytes: ...
def data_slice(self, start: int, end: int) -> bytes: ...
def eof(self) -> bool: ...
def seek(self, pos: int) -> None: ...
def tell(self) -> int: ...
def pull_bytes(self, length: int) -> bytes: ...
def pull_uint8(self) -> int: ...
def pull_uint16(self) -> int: ...
def pull_uint32(self) -> int: ...
def pull_uint64(self) -> int: ...
def pull_uint_var(self) -> int: ...
def push_bytes(self, value: bytes) -> None: ...
def push_uint8(self, value: int) -> None: ...
def push_uint16(self, value: int) -> None: ...
def push_uint32(self, v: int) -> None: ...
def push_uint64(self, v: int) -> None: ...
def push_uint_var(self, value: int) -> None: ...
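The stub above describes the C-accelerated Buffer used throughout the QUIC code, including QUIC variable-length integers. A round-trip sketch, assuming the aioquic.buffer wrapper (imported later as `from ..buffer import Buffer`) re-exports these names:
from aioquic.buffer import Buffer, BufferReadError

buf = Buffer(capacity=16)
buf.push_uint8(0x42)
buf.push_uint_var(1_000_000)      # variable-length integer, 4-byte encoding
wire = buf.data                   # bytes written so far

rx = Buffer(data=wire)
assert rx.pull_uint8() == 0x42
assert rx.pull_uint_var() == 1_000_000
assert rx.eof()
try:
    rx.pull_uint8()               # reading past the end raises BufferReadError
except BufferReadError:
    pass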


@ -0,0 +1,455 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#define MODULE_NAME "aioquic._crypto"
#define AEAD_KEY_LENGTH_MAX 32
#define AEAD_NONCE_LENGTH 12
#define AEAD_TAG_LENGTH 16
#define PACKET_LENGTH_MAX 1500
#define PACKET_NUMBER_LENGTH_MAX 4
#define SAMPLE_LENGTH 16
#define CHECK_RESULT(expr) \
if (!(expr)) { \
ERR_clear_error(); \
PyErr_SetString(CryptoError, "OpenSSL call failed"); \
return NULL; \
}
#define CHECK_RESULT_CTOR(expr) \
if (!(expr)) { \
ERR_clear_error(); \
PyErr_SetString(CryptoError, "OpenSSL call failed"); \
return -1; \
}
static PyObject *CryptoError;
/* AEAD */
typedef struct {
PyObject_HEAD
EVP_CIPHER_CTX *decrypt_ctx;
EVP_CIPHER_CTX *encrypt_ctx;
unsigned char buffer[PACKET_LENGTH_MAX];
unsigned char key[AEAD_KEY_LENGTH_MAX];
unsigned char iv[AEAD_NONCE_LENGTH];
unsigned char nonce[AEAD_NONCE_LENGTH];
} AEADObject;
static EVP_CIPHER_CTX *
create_ctx(const EVP_CIPHER *cipher, int key_length, int operation)
{
EVP_CIPHER_CTX *ctx;
int res;
ctx = EVP_CIPHER_CTX_new();
CHECK_RESULT(ctx != 0);
res = EVP_CipherInit_ex(ctx, cipher, NULL, NULL, NULL, operation);
CHECK_RESULT(res != 0);
res = EVP_CIPHER_CTX_set_key_length(ctx, key_length);
CHECK_RESULT(res != 0);
res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_CCM_SET_IVLEN, AEAD_NONCE_LENGTH, NULL);
CHECK_RESULT(res != 0);
return ctx;
}
static int
AEAD_init(AEADObject *self, PyObject *args, PyObject *kwargs)
{
const char *cipher_name;
const unsigned char *key, *iv;
Py_ssize_t cipher_name_len, key_len, iv_len;
if (!PyArg_ParseTuple(args, "y#y#y#", &cipher_name, &cipher_name_len, &key, &key_len, &iv, &iv_len))
return -1;
const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
if (evp_cipher == 0) {
PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
return -1;
}
if (key_len > AEAD_KEY_LENGTH_MAX) {
PyErr_SetString(CryptoError, "Invalid key length");
return -1;
}
if (iv_len > AEAD_NONCE_LENGTH) {
PyErr_SetString(CryptoError, "Invalid iv length");
return -1;
}
memcpy(self->key, key, key_len);
memcpy(self->iv, iv, iv_len);
self->decrypt_ctx = create_ctx(evp_cipher, key_len, 0);
CHECK_RESULT_CTOR(self->decrypt_ctx != 0);
self->encrypt_ctx = create_ctx(evp_cipher, key_len, 1);
CHECK_RESULT_CTOR(self->encrypt_ctx != 0);
return 0;
}
static void
AEAD_dealloc(AEADObject *self)
{
EVP_CIPHER_CTX_free(self->decrypt_ctx);
EVP_CIPHER_CTX_free(self->encrypt_ctx);
}
static PyObject*
AEAD_decrypt(AEADObject *self, PyObject *args)
{
const unsigned char *data, *associated;
Py_ssize_t data_len, associated_len;
int outlen, outlen2, res;
uint64_t pn;
if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
return NULL;
if (data_len < AEAD_TAG_LENGTH || data_len > PACKET_LENGTH_MAX) {
PyErr_SetString(CryptoError, "Invalid payload length");
return NULL;
}
memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
for (int i = 0; i < 8; ++i) {
self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
}
res = EVP_CIPHER_CTX_ctrl(self->decrypt_ctx, EVP_CTRL_CCM_SET_TAG, AEAD_TAG_LENGTH, (void*)(data + (data_len - AEAD_TAG_LENGTH)));
CHECK_RESULT(res != 0);
res = EVP_CipherInit_ex(self->decrypt_ctx, NULL, NULL, self->key, self->nonce, 0);
CHECK_RESULT(res != 0);
res = EVP_CipherUpdate(self->decrypt_ctx, NULL, &outlen, associated, associated_len);
CHECK_RESULT(res != 0);
res = EVP_CipherUpdate(self->decrypt_ctx, self->buffer, &outlen, data, data_len - AEAD_TAG_LENGTH);
CHECK_RESULT(res != 0);
res = EVP_CipherFinal_ex(self->decrypt_ctx, NULL, &outlen2);
if (res == 0) {
PyErr_SetString(CryptoError, "Payload decryption failed");
return NULL;
}
return PyBytes_FromStringAndSize((const char*)self->buffer, outlen);
}
static PyObject*
AEAD_encrypt(AEADObject *self, PyObject *args)
{
const unsigned char *data, *associated;
Py_ssize_t data_len, associated_len;
int outlen, outlen2, res;
uint64_t pn;
if (!PyArg_ParseTuple(args, "y#y#K", &data, &data_len, &associated, &associated_len, &pn))
return NULL;
if (data_len > PACKET_LENGTH_MAX) {
PyErr_SetString(CryptoError, "Invalid payload length");
return NULL;
}
memcpy(self->nonce, self->iv, AEAD_NONCE_LENGTH);
for (int i = 0; i < 8; ++i) {
self->nonce[AEAD_NONCE_LENGTH - 1 - i] ^= (uint8_t)(pn >> 8 * i);
}
res = EVP_CipherInit_ex(self->encrypt_ctx, NULL, NULL, self->key, self->nonce, 1);
CHECK_RESULT(res != 0);
res = EVP_CipherUpdate(self->encrypt_ctx, NULL, &outlen, associated, associated_len);
CHECK_RESULT(res != 0);
res = EVP_CipherUpdate(self->encrypt_ctx, self->buffer, &outlen, data, data_len);
CHECK_RESULT(res != 0);
res = EVP_CipherFinal_ex(self->encrypt_ctx, NULL, &outlen2);
CHECK_RESULT(res != 0 && outlen2 == 0);
res = EVP_CIPHER_CTX_ctrl(self->encrypt_ctx, EVP_CTRL_CCM_GET_TAG, AEAD_TAG_LENGTH, self->buffer + outlen);
CHECK_RESULT(res != 0);
return PyBytes_FromStringAndSize((const char*)self->buffer, outlen + AEAD_TAG_LENGTH);
}
static PyMethodDef AEAD_methods[] = {
{"decrypt", (PyCFunction)AEAD_decrypt, METH_VARARGS, ""},
{"encrypt", (PyCFunction)AEAD_encrypt, METH_VARARGS, ""},
{NULL}
};
static PyTypeObject AEADType = {
PyVarObject_HEAD_INIT(NULL, 0)
MODULE_NAME ".AEAD", /* tp_name */
sizeof(AEADObject), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)AEAD_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"AEAD objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
AEAD_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)AEAD_init, /* tp_init */
0, /* tp_alloc */
};
/* HeaderProtection */
typedef struct {
PyObject_HEAD
EVP_CIPHER_CTX *ctx;
int is_chacha20;
unsigned char buffer[PACKET_LENGTH_MAX];
unsigned char mask[31];
unsigned char zero[5];
} HeaderProtectionObject;
static int
HeaderProtection_init(HeaderProtectionObject *self, PyObject *args, PyObject *kwargs)
{
const char *cipher_name;
const unsigned char *key;
Py_ssize_t cipher_name_len, key_len;
int res;
if (!PyArg_ParseTuple(args, "y#y#", &cipher_name, &cipher_name_len, &key, &key_len))
return -1;
const EVP_CIPHER *evp_cipher = EVP_get_cipherbyname(cipher_name);
if (evp_cipher == 0) {
PyErr_Format(CryptoError, "Invalid cipher name: %s", cipher_name);
return -1;
}
memset(self->mask, 0, sizeof(self->mask));
memset(self->zero, 0, sizeof(self->zero));
self->is_chacha20 = cipher_name_len == 8 && memcmp(cipher_name, "chacha20", 8) == 0;
self->ctx = EVP_CIPHER_CTX_new();
CHECK_RESULT_CTOR(self->ctx != 0);
res = EVP_CipherInit_ex(self->ctx, evp_cipher, NULL, NULL, NULL, 1);
CHECK_RESULT_CTOR(res != 0);
res = EVP_CIPHER_CTX_set_key_length(self->ctx, key_len);
CHECK_RESULT_CTOR(res != 0);
res = EVP_CipherInit_ex(self->ctx, NULL, NULL, key, NULL, 1);
CHECK_RESULT_CTOR(res != 0);
return 0;
}
static void
HeaderProtection_dealloc(HeaderProtectionObject *self)
{
EVP_CIPHER_CTX_free(self->ctx);
}
static int HeaderProtection_mask(HeaderProtectionObject *self, const unsigned char* sample)
{
int outlen;
if (self->is_chacha20) {
return EVP_CipherInit_ex(self->ctx, NULL, NULL, NULL, sample, 1) &&
EVP_CipherUpdate(self->ctx, self->mask, &outlen, self->zero, sizeof(self->zero));
} else {
return EVP_CipherUpdate(self->ctx, self->mask, &outlen, sample, SAMPLE_LENGTH);
}
}
static PyObject*
HeaderProtection_apply(HeaderProtectionObject *self, PyObject *args)
{
const unsigned char *header, *payload;
Py_ssize_t header_len, payload_len;
int res;
if (!PyArg_ParseTuple(args, "y#y#", &header, &header_len, &payload, &payload_len))
return NULL;
int pn_length = (header[0] & 0x03) + 1;
int pn_offset = header_len - pn_length;
res = HeaderProtection_mask(self, payload + PACKET_NUMBER_LENGTH_MAX - pn_length);
CHECK_RESULT(res != 0);
memcpy(self->buffer, header, header_len);
memcpy(self->buffer + header_len, payload, payload_len);
if (self->buffer[0] & 0x80) {
self->buffer[0] ^= self->mask[0] & 0x0F;
} else {
self->buffer[0] ^= self->mask[0] & 0x1F;
}
for (int i = 0; i < pn_length; ++i) {
self->buffer[pn_offset + i] ^= self->mask[1 + i];
}
return PyBytes_FromStringAndSize((const char*)self->buffer, header_len + payload_len);
}
static PyObject*
HeaderProtection_remove(HeaderProtectionObject *self, PyObject *args)
{
const unsigned char *packet;
Py_ssize_t packet_len;
int pn_offset, res;
if (!PyArg_ParseTuple(args, "y#I", &packet, &packet_len, &pn_offset))
return NULL;
res = HeaderProtection_mask(self, packet + pn_offset + PACKET_NUMBER_LENGTH_MAX);
CHECK_RESULT(res != 0);
memcpy(self->buffer, packet, pn_offset + PACKET_NUMBER_LENGTH_MAX);
if (self->buffer[0] & 0x80) {
self->buffer[0] ^= self->mask[0] & 0x0F;
} else {
self->buffer[0] ^= self->mask[0] & 0x1F;
}
int pn_length = (self->buffer[0] & 0x03) + 1;
uint32_t pn_truncated = 0;
for (int i = 0; i < pn_length; ++i) {
self->buffer[pn_offset + i] ^= self->mask[1 + i];
pn_truncated = self->buffer[pn_offset + i] | (pn_truncated << 8);
}
return Py_BuildValue("y#i", self->buffer, pn_offset + pn_length, pn_truncated);
}
static PyMethodDef HeaderProtection_methods[] = {
{"apply", (PyCFunction)HeaderProtection_apply, METH_VARARGS, ""},
{"remove", (PyCFunction)HeaderProtection_remove, METH_VARARGS, ""},
{NULL}
};
static PyTypeObject HeaderProtectionType = {
PyVarObject_HEAD_INIT(NULL, 0)
MODULE_NAME ".HeaderProtection", /* tp_name */
sizeof(HeaderProtectionObject), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)HeaderProtection_dealloc, /* tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
0, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"HeaderProtection objects", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
HeaderProtection_methods, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc)HeaderProtection_init, /* tp_init */
0, /* tp_alloc */
};
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
MODULE_NAME, /* m_name */
"A faster buffer.", /* m_doc */
-1, /* m_size */
NULL, /* m_methods */
NULL, /* m_reload */
NULL, /* m_traverse */
NULL, /* m_clear */
NULL, /* m_free */
};
PyMODINIT_FUNC
PyInit__crypto(void)
{
PyObject* m;
m = PyModule_Create(&moduledef);
if (m == NULL)
return NULL;
CryptoError = PyErr_NewException(MODULE_NAME ".CryptoError", PyExc_ValueError, NULL);
Py_INCREF(CryptoError);
PyModule_AddObject(m, "CryptoError", CryptoError);
AEADType.tp_new = PyType_GenericNew;
if (PyType_Ready(&AEADType) < 0)
return NULL;
Py_INCREF(&AEADType);
PyModule_AddObject(m, "AEAD", (PyObject *)&AEADType);
HeaderProtectionType.tp_new = PyType_GenericNew;
if (PyType_Ready(&HeaderProtectionType) < 0)
return NULL;
Py_INCREF(&HeaderProtectionType);
PyModule_AddObject(m, "HeaderProtection", (PyObject *)&HeaderProtectionType);
// ensure required ciphers are initialised
EVP_add_cipher(EVP_aes_128_ecb());
EVP_add_cipher(EVP_aes_128_gcm());
EVP_add_cipher(EVP_aes_256_ecb());
EVP_add_cipher(EVP_aes_256_gcm());
return m;
}


@ -0,0 +1,17 @@
from typing import Tuple
class AEAD:
def __init__(self, cipher_name: bytes, key: bytes, iv: bytes): ...
def decrypt(
self, data: bytes, associated_data: bytes, packet_number: int
) -> bytes: ...
def encrypt(
self, data: bytes, associated_data: bytes, packet_number: int
) -> bytes: ...
class CryptoError(ValueError): ...
class HeaderProtection:
def __init__(self, cipher_name: bytes, key: bytes): ...
def apply(self, plain_header: bytes, protected_payload: bytes) -> bytes: ...
def remove(self, packet: bytes, encrypted_offset: int) -> Tuple[bytes, int]: ...
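The stub above covers the OpenSSL-backed payload protection. An encrypt/decrypt round trip, assuming the aes-128-gcm cipher registered in the C module above, a 16-byte key, and the 12-byte nonce length defined by AEAD_NONCE_LENGTH:
import os
from aioquic._crypto import AEAD, CryptoError

key, iv = os.urandom(16), os.urandom(12)
aead = AEAD(b"aes-128-gcm", key, iv)

protected = aead.encrypt(b"hello", b"header", 0)   # payload, associated data, packet number
assert aead.decrypt(protected, b"header", 0) == b"hello"

try:
    aead.decrypt(protected, b"tampered", 0)        # wrong associated data fails the tag check
except CryptoError:
    pass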


@ -0,0 +1,7 @@
__author__ = "Jeremy Lainé"
__email__ = "jeremy.laine@m4x.org"
__license__ = "BSD"
__summary__ = "An implementation of QUIC and HTTP/3"
__title__ = "aioquic"
__uri__ = "https://github.com/aiortc/aioquic"
__version__ = "0.8.7"


@ -0,0 +1,3 @@
from .client import connect # noqa
from .protocol import QuicConnectionProtocol # noqa
from .server import serve # noqa


@ -0,0 +1,94 @@
import asyncio
import ipaddress
import socket
import sys
from typing import AsyncGenerator, Callable, Optional, cast
from ..quic.configuration import QuicConfiguration
from ..quic.connection import QuicConnection
from ..tls import SessionTicketHandler
from .compat import asynccontextmanager
from .protocol import QuicConnectionProtocol, QuicStreamHandler
__all__ = ["connect"]
@asynccontextmanager
async def connect(
host: str,
port: int,
*,
configuration: Optional[QuicConfiguration] = None,
create_protocol: Optional[Callable] = QuicConnectionProtocol,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stream_handler: Optional[QuicStreamHandler] = None,
wait_connected: bool = True,
local_port: int = 0,
) -> AsyncGenerator[QuicConnectionProtocol, None]:
"""
Connect to a QUIC server at the given `host` and `port`.
:meth:`connect()` returns an awaitable. Awaiting it yields a
:class:`~aioquic.asyncio.QuicConnectionProtocol` which can be used to
create streams.
:func:`connect` also accepts the following optional arguments:
* ``configuration`` is a :class:`~aioquic.quic.configuration.QuicConfiguration`
configuration object.
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
manages the connection. It should be a callable or class accepting the same
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
* ``session_ticket_handler`` is a callback which is invoked by the TLS
engine when a new session ticket is received.
* ``stream_handler`` is a callback which is invoked whenever a stream is
created. It must accept two arguments: a :class:`asyncio.StreamReader`
and a :class:`asyncio.StreamWriter`.
* ``local_port`` is the UDP port number that this client wants to bind.
"""
loop = asyncio.get_event_loop()
local_host = "::"
# if host is not an IP address, pass it to enable SNI
try:
ipaddress.ip_address(host)
server_name = None
except ValueError:
server_name = host
# lookup remote address
infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
addr = infos[0][4]
if len(addr) == 2:
# determine behaviour for IPv4
if sys.platform == "win32":
# on Windows, we must use an IPv4 socket to reach an IPv4 host
local_host = "0.0.0.0"
else:
# other platforms support dual-stack sockets
addr = ("::ffff:" + addr[0], addr[1], 0, 0)
# prepare QUIC connection
if configuration is None:
configuration = QuicConfiguration(is_client=True)
if server_name is not None:
configuration.server_name = server_name
connection = QuicConnection(
configuration=configuration, session_ticket_handler=session_ticket_handler
)
# connect
_, protocol = await loop.create_datagram_endpoint(
lambda: create_protocol(connection, stream_handler=stream_handler),
local_addr=(local_host, local_port),
)
protocol = cast(QuicConnectionProtocol, protocol)
protocol.connect(addr)
if wait_connected:
await protocol.wait_connected()
try:
yield protocol
finally:
protocol.close()
await protocol.wait_closed()
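The docstring above lists the knobs accepted by connect(). A minimal client sketch, assuming a reachable HTTP/3-capable endpoint at example.com:4433:
import asyncio
from aioquic.asyncio import connect
from aioquic.h3.connection import H3_ALPN
from aioquic.quic.configuration import QuicConfiguration

async def main() -> None:
    configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN)
    async with connect("example.com", 4433, configuration=configuration) as protocol:
        await protocol.ping()     # one round trip over the established connection

asyncio.get_event_loop().run_until_complete(main())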


@ -0,0 +1,33 @@
from contextlib import ContextDecorator
from functools import wraps
try:
from contextlib import asynccontextmanager
except ImportError:
asynccontextmanager = None
class _AsyncGeneratorContextManager(ContextDecorator):
def __init__(self, func, args, kwds):
self.gen = func(*args, **kwds)
self.func, self.args, self.kwds = func, args, kwds
self.__doc__ = func.__doc__
async def __aenter__(self):
return await self.gen.__anext__()
async def __aexit__(self, typ, value, traceback):
if typ is not None:
await self.gen.athrow(typ, value, traceback)
def _asynccontextmanager(func):
@wraps(func)
def helper(*args, **kwds):
return _AsyncGeneratorContextManager(func, args, kwds)
return helper
if asynccontextmanager is None:
asynccontextmanager = _asynccontextmanager
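connect() above relies on this decorator. A usage sketch; note that on a clean exit this simplified shim leaves the async generator suspended at its yield, so cleanup placed after the yield only runs when an exception is thrown into the block:
from aioquic.asyncio.compat import asynccontextmanager

@asynccontextmanager
async def session():
    yield "resource"              # value bound by `async with ... as`

async def use() -> None:
    async with session() as value:
        assert value == "resource"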


@ -0,0 +1,240 @@
import asyncio
from typing import Any, Callable, Dict, Optional, Text, Tuple, Union, cast
from ..quic import events
from ..quic.connection import NetworkAddress, QuicConnection
QuicConnectionIdHandler = Callable[[bytes], None]
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
class QuicConnectionProtocol(asyncio.DatagramProtocol):
def __init__(
self, quic: QuicConnection, stream_handler: Optional[QuicStreamHandler] = None
):
loop = asyncio.get_event_loop()
self._closed = asyncio.Event()
self._connected = False
self._connected_waiter: Optional[asyncio.Future[None]] = None
self._loop = loop
self._ping_waiters: Dict[int, asyncio.Future[None]] = {}
self._quic = quic
self._stream_readers: Dict[int, asyncio.StreamReader] = {}
self._timer: Optional[asyncio.TimerHandle] = None
self._timer_at: Optional[float] = None
self._transmit_task: Optional[asyncio.Handle] = None
self._transport: Optional[asyncio.DatagramTransport] = None
# callbacks
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
self._connection_terminated_handler: Callable[[], None] = lambda: None
if stream_handler is not None:
self._stream_handler = stream_handler
else:
self._stream_handler = lambda r, w: None
def change_connection_id(self) -> None:
"""
Change the connection ID used to communicate with the peer.
The previous connection ID will be retired.
"""
self._quic.change_connection_id()
self.transmit()
def close(self) -> None:
"""
Close the connection.
"""
self._quic.close()
self.transmit()
def connect(self, addr: NetworkAddress) -> None:
"""
Initiate the TLS handshake.
This method can only be called for clients and a single time.
"""
self._quic.connect(addr, now=self._loop.time())
self.transmit()
async def create_stream(
self, is_unidirectional: bool = False
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""
Create a QUIC stream and return a pair of (reader, writer) objects.
The returned reader and writer objects are instances of :class:`asyncio.StreamReader`
and :class:`asyncio.StreamWriter` classes.
"""
stream_id = self._quic.get_next_available_stream_id(
is_unidirectional=is_unidirectional
)
return self._create_stream(stream_id)
def request_key_update(self) -> None:
"""
Request an update of the encryption keys.
"""
self._quic.request_key_update()
self.transmit()
async def ping(self) -> None:
"""
Ping the peer and wait for the response.
"""
waiter = self._loop.create_future()
uid = id(waiter)
self._ping_waiters[uid] = waiter
self._quic.send_ping(uid)
self.transmit()
await asyncio.shield(waiter)
def transmit(self) -> None:
"""
Send pending datagrams to the peer and arm the timer if needed.
"""
self._transmit_task = None
# send datagrams
for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
self._transport.sendto(data, addr)
# re-arm timer
timer_at = self._quic.get_timer()
if self._timer is not None and self._timer_at != timer_at:
self._timer.cancel()
self._timer = None
if self._timer is None and timer_at is not None:
self._timer = self._loop.call_at(timer_at, self._handle_timer)
self._timer_at = timer_at
async def wait_closed(self) -> None:
"""
Wait for the connection to be closed.
"""
await self._closed.wait()
async def wait_connected(self) -> None:
"""
Wait for the TLS handshake to complete.
"""
assert self._connected_waiter is None, "already awaiting connected"
if not self._connected:
self._connected_waiter = self._loop.create_future()
await asyncio.shield(self._connected_waiter)
# asyncio.Transport
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
self._process_events()
self.transmit()
# overridable
def quic_event_received(self, event: events.QuicEvent) -> None:
"""
Called when a QUIC event is received.
Reimplement this in your subclass to handle the events.
"""
# FIXME: move this to a subclass
if isinstance(event, events.ConnectionTerminated):
for reader in self._stream_readers.values():
reader.feed_eof()
elif isinstance(event, events.StreamDataReceived):
reader = self._stream_readers.get(event.stream_id, None)
if reader is None:
reader, writer = self._create_stream(event.stream_id)
self._stream_handler(reader, writer)
reader.feed_data(event.data)
if event.end_stream:
reader.feed_eof()
# private
def _create_stream(
self, stream_id: int
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
adapter = QuicStreamAdapter(self, stream_id)
reader = asyncio.StreamReader()
writer = asyncio.StreamWriter(adapter, None, reader, self._loop)
self._stream_readers[stream_id] = reader
return reader, writer
def _handle_timer(self) -> None:
now = max(self._timer_at, self._loop.time())
self._timer = None
self._timer_at = None
self._quic.handle_timer(now=now)
self._process_events()
self.transmit()
def _process_events(self) -> None:
event = self._quic.next_event()
while event is not None:
if isinstance(event, events.ConnectionIdIssued):
self._connection_id_issued_handler(event.connection_id)
elif isinstance(event, events.ConnectionIdRetired):
self._connection_id_retired_handler(event.connection_id)
elif isinstance(event, events.ConnectionTerminated):
self._connection_terminated_handler()
# abort connection waiter
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected_waiter = None
waiter.set_exception(ConnectionError)
# abort ping waiters
for waiter in self._ping_waiters.values():
waiter.set_exception(ConnectionError)
self._ping_waiters.clear()
self._closed.set()
elif isinstance(event, events.HandshakeCompleted):
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected = True
self._connected_waiter = None
waiter.set_result(None)
elif isinstance(event, events.PingAcknowledged):
waiter = self._ping_waiters.pop(event.uid, None)
if waiter is not None:
waiter.set_result(None)
self.quic_event_received(event)
event = self._quic.next_event()
def _transmit_soon(self) -> None:
if self._transmit_task is None:
self._transmit_task = self._loop.call_soon(self.transmit)
class QuicStreamAdapter(asyncio.Transport):
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
self.protocol = protocol
self.stream_id = stream_id
def can_write_eof(self) -> bool:
return True
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""
Get information about the underlying QUIC stream.
"""
if name == "stream_id":
return self.stream_id
def write(self, data):
self.protocol._quic.send_stream_data(self.stream_id, data)
self.protocol._transmit_soon()
def write_eof(self):
self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
self.protocol._transmit_soon()


@ -0,0 +1,210 @@
import asyncio
import os
from functools import partial
from typing import Callable, Dict, Optional, Text, Union, cast
from ..buffer import Buffer
from ..quic.configuration import QuicConfiguration
from ..quic.connection import NetworkAddress, QuicConnection
from ..quic.packet import (
PACKET_TYPE_INITIAL,
encode_quic_retry,
encode_quic_version_negotiation,
pull_quic_header,
)
from ..quic.retry import QuicRetryTokenHandler
from ..tls import SessionTicketFetcher, SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler
__all__ = ["serve"]
class QuicServer(asyncio.DatagramProtocol):
def __init__(
self,
*,
configuration: QuicConfiguration,
create_protocol: Callable = QuicConnectionProtocol,
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stateless_retry: bool = False,
stream_handler: Optional[QuicStreamHandler] = None,
) -> None:
self._configuration = configuration
self._create_protocol = create_protocol
self._loop = asyncio.get_event_loop()
self._protocols: Dict[bytes, QuicConnectionProtocol] = {}
self._session_ticket_fetcher = session_ticket_fetcher
self._session_ticket_handler = session_ticket_handler
self._transport: Optional[asyncio.DatagramTransport] = None
self._stream_handler = stream_handler
if stateless_retry:
self._retry = QuicRetryTokenHandler()
else:
self._retry = None
def close(self):
for protocol in set(self._protocols.values()):
protocol.close()
self._protocols.clear()
self._transport.close()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> None:
data = cast(bytes, data)
buf = Buffer(data=data)
try:
header = pull_quic_header(
buf, host_cid_length=self._configuration.connection_id_length
)
except ValueError:
return
# version negotiation
if (
header.version is not None
and header.version not in self._configuration.supported_versions
):
self._transport.sendto(
encode_quic_version_negotiation(
source_cid=header.destination_cid,
destination_cid=header.source_cid,
supported_versions=self._configuration.supported_versions,
),
addr,
)
return
protocol = self._protocols.get(header.destination_cid, None)
original_connection_id: Optional[bytes] = None
if (
protocol is None
and len(data) >= 1200
and header.packet_type == PACKET_TYPE_INITIAL
):
# stateless retry
if self._retry is not None:
if not header.token:
# create a retry token
self._transport.sendto(
encode_quic_retry(
version=header.version,
source_cid=os.urandom(8),
destination_cid=header.source_cid,
original_destination_cid=header.destination_cid,
retry_token=self._retry.create_token(
addr, header.destination_cid
),
),
addr,
)
return
else:
# validate retry token
try:
original_connection_id = self._retry.validate_token(
addr, header.token
)
except ValueError:
return
# create new connection
connection = QuicConnection(
configuration=self._configuration,
logger_connection_id=original_connection_id or header.destination_cid,
original_connection_id=original_connection_id,
session_ticket_fetcher=self._session_ticket_fetcher,
session_ticket_handler=self._session_ticket_handler,
)
protocol = self._create_protocol(
connection, stream_handler=self._stream_handler
)
protocol.connection_made(self._transport)
# register callbacks
protocol._connection_id_issued_handler = partial(
self._connection_id_issued, protocol=protocol
)
protocol._connection_id_retired_handler = partial(
self._connection_id_retired, protocol=protocol
)
protocol._connection_terminated_handler = partial(
self._connection_terminated, protocol=protocol
)
self._protocols[header.destination_cid] = protocol
self._protocols[connection.host_cid] = protocol
if protocol is not None:
protocol.datagram_received(data, addr)
def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol):
self._protocols[cid] = protocol
def _connection_id_retired(
self, cid: bytes, protocol: QuicConnectionProtocol
) -> None:
assert self._protocols[cid] == protocol
del self._protocols[cid]
def _connection_terminated(self, protocol: QuicConnectionProtocol):
for cid, proto in list(self._protocols.items()):
if proto == protocol:
del self._protocols[cid]
async def serve(
host: str,
port: int,
*,
configuration: QuicConfiguration,
create_protocol: Callable = QuicConnectionProtocol,
session_ticket_fetcher: Optional[SessionTicketFetcher] = None,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stateless_retry: bool = False,
stream_handler: QuicStreamHandler = None,
) -> QuicServer:
"""
Start a QUIC server at the given `host` and `port`.
:func:`serve` requires a :class:`~aioquic.quic.configuration.QuicConfiguration`
containing TLS certificate and private key as the ``configuration`` argument.
:func:`serve` also accepts the following optional arguments:
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
manages the connection. It should be a callable or class accepting the same
arguments as :class:`~aioquic.asyncio.QuicConnectionProtocol` and returning
an instance of :class:`~aioquic.asyncio.QuicConnectionProtocol` or a subclass.
* ``session_ticket_fetcher`` is a callback which is invoked by the TLS
engine when a session ticket is presented by the peer. It should return
the session ticket with the specified ID or `None` if it is not found.
* ``session_ticket_handler`` is a callback which is invoked by the TLS
engine when a new session ticket is issued. It should store the session
ticket for future lookup.
* ``stateless_retry`` specifies whether a stateless retry should be
performed prior to handling new connections.
* ``stream_handler`` is a callback which is invoked whenever a stream is
created. It must accept two arguments: a :class:`asyncio.StreamReader`
and a :class:`asyncio.StreamWriter`.
"""
loop = asyncio.get_event_loop()
_, protocol = await loop.create_datagram_endpoint(
lambda: QuicServer(
configuration=configuration,
create_protocol=create_protocol,
session_ticket_fetcher=session_ticket_fetcher,
session_ticket_handler=session_ticket_handler,
stateless_retry=stateless_retry,
stream_handler=stream_handler,
),
local_addr=(host, port),
)
return cast(QuicServer, protocol)
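A minimal usage sketch for serve() (illustrative only, not part of the vendored file). It assumes the placeholder certificate and key paths exist, that serve is re-exported from aioquic.asyncio as in upstream aioquic, and that the default protocol hands incoming streams to the stream_handler callback:

import asyncio

from aioquic.asyncio import serve
from aioquic.quic.configuration import QuicConfiguration

async def echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    data = await reader.read()  # read until the peer ends the stream
    writer.write(data)          # routed through QuicStreamAdapter.write
    writer.write_eof()

async def main() -> None:
    configuration = QuicConfiguration(is_client=False)
    # Placeholder paths: substitute a real certificate and private key.
    configuration.load_cert_chain("ssl_cert.pem", "ssl_key.pem")
    await serve("::", 4433, configuration=configuration, stream_handler=echo)
    await asyncio.Event().wait()  # keep serving until cancelled

asyncio.run(main())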


@ -0,0 +1,29 @@
from ._buffer import Buffer, BufferReadError, BufferWriteError # noqa
UINT_VAR_MAX = 0x3FFFFFFFFFFFFFFF
def encode_uint_var(value: int) -> bytes:
"""
Encode a variable-length unsigned integer.
"""
buf = Buffer(capacity=8)
buf.push_uint_var(value)
return buf.data
def size_uint_var(value: int) -> int:
"""
Return the number of bytes required to encode the given value
as a QUIC variable-length unsigned integer.
"""
if value <= 0x3F:
return 1
elif value <= 0x3FFF:
return 2
elif value <= 0x3FFFFFFF:
return 4
elif value <= 0x3FFFFFFFFFFFFFFF:
return 8
else:
raise ValueError("Integer is too big for a variable-length integer")


@ -0,0 +1,63 @@
from typing import Dict, List
from aioquic.h3.events import DataReceived, H3Event, Headers, HeadersReceived
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import QuicEvent, StreamDataReceived
H0_ALPN = ["hq-27", "hq-26", "hq-25"]
class H0Connection:
"""
An HTTP/0.9 connection object.
"""
def __init__(self, quic: QuicConnection):
self._headers_received: Dict[int, bool] = {}
self._is_client = quic.configuration.is_client
self._quic = quic
def handle_event(self, event: QuicEvent) -> List[H3Event]:
http_events: List[H3Event] = []
if isinstance(event, StreamDataReceived) and (event.stream_id % 4) == 0:
data = event.data
if not self._headers_received.get(event.stream_id, False):
if self._is_client:
http_events.append(
HeadersReceived(
headers=[], stream_ended=False, stream_id=event.stream_id
)
)
else:
method, path = data.rstrip().split(b" ", 1)
http_events.append(
HeadersReceived(
headers=[(b":method", method), (b":path", path)],
stream_ended=False,
stream_id=event.stream_id,
)
)
data = b""
self._headers_received[event.stream_id] = True
http_events.append(
DataReceived(
data=data, stream_ended=event.end_stream, stream_id=event.stream_id
)
)
return http_events
def send_data(self, stream_id: int, data: bytes, end_stream: bool) -> None:
self._quic.send_stream_data(stream_id, data, end_stream)
def send_headers(
self, stream_id: int, headers: Headers, end_stream: bool = False
) -> None:
if self._is_client:
headers_dict = dict(headers)
data = headers_dict[b":method"] + b" " + headers_dict[b":path"] + b"\r\n"
else:
data = b""
self._quic.send_stream_data(stream_id, data, end_stream)


@ -0,0 +1,776 @@
import logging
from enum import Enum, IntEnum
from typing import Dict, List, Optional, Set
import pylsqpack
from aioquic.buffer import Buffer, BufferReadError, encode_uint_var
from aioquic.h3.events import (
DataReceived,
H3Event,
Headers,
HeadersReceived,
PushPromiseReceived,
)
from aioquic.h3.exceptions import NoAvailablePushIDError
from aioquic.quic.connection import QuicConnection, stream_is_unidirectional
from aioquic.quic.events import QuicEvent, StreamDataReceived
from aioquic.quic.logger import QuicLoggerTrace
logger = logging.getLogger("http3")
H3_ALPN = ["h3-27", "h3-26", "h3-25"]
class ErrorCode(IntEnum):
HTTP_NO_ERROR = 0x100
HTTP_GENERAL_PROTOCOL_ERROR = 0x101
HTTP_INTERNAL_ERROR = 0x102
HTTP_STREAM_CREATION_ERROR = 0x103
HTTP_CLOSED_CRITICAL_STREAM = 0x104
HTTP_FRAME_UNEXPECTED = 0x105
HTTP_FRAME_ERROR = 0x106
HTTP_EXCESSIVE_LOAD = 0x107
HTTP_ID_ERROR = 0x108
HTTP_SETTINGS_ERROR = 0x109
HTTP_MISSING_SETTINGS = 0x10A
HTTP_REQUEST_REJECTED = 0x10B
HTTP_REQUEST_CANCELLED = 0x10C
HTTP_REQUEST_INCOMPLETE = 0x10D
HTTP_EARLY_RESPONSE = 0x10E
HTTP_CONNECT_ERROR = 0x10F
HTTP_VERSION_FALLBACK = 0x110
HTTP_QPACK_DECOMPRESSION_FAILED = 0x200
HTTP_QPACK_ENCODER_STREAM_ERROR = 0x201
HTTP_QPACK_DECODER_STREAM_ERROR = 0x202
class FrameType(IntEnum):
DATA = 0x0
HEADERS = 0x1
PRIORITY = 0x2
CANCEL_PUSH = 0x3
SETTINGS = 0x4
PUSH_PROMISE = 0x5
GOAWAY = 0x7
MAX_PUSH_ID = 0xD
DUPLICATE_PUSH = 0xE
class HeadersState(Enum):
INITIAL = 0
AFTER_HEADERS = 1
AFTER_TRAILERS = 2
class Setting(IntEnum):
QPACK_MAX_TABLE_CAPACITY = 1
SETTINGS_MAX_HEADER_LIST_SIZE = 6
QPACK_BLOCKED_STREAMS = 7
SETTINGS_NUM_PLACEHOLDERS = 9
class StreamType(IntEnum):
CONTROL = 0
PUSH = 1
QPACK_ENCODER = 2
QPACK_DECODER = 3
class ProtocolError(Exception):
"""
Base class for protocol errors.
    These errors are not exposed to the API user; they are handled
in :meth:`H3Connection.handle_event`.
"""
error_code = ErrorCode.HTTP_GENERAL_PROTOCOL_ERROR
def __init__(self, reason_phrase: str = ""):
self.reason_phrase = reason_phrase
class QpackDecompressionFailed(ProtocolError):
error_code = ErrorCode.HTTP_QPACK_DECOMPRESSION_FAILED
class QpackDecoderStreamError(ProtocolError):
error_code = ErrorCode.HTTP_QPACK_DECODER_STREAM_ERROR
class QpackEncoderStreamError(ProtocolError):
error_code = ErrorCode.HTTP_QPACK_ENCODER_STREAM_ERROR
class StreamCreationError(ProtocolError):
error_code = ErrorCode.HTTP_STREAM_CREATION_ERROR
class FrameUnexpected(ProtocolError):
error_code = ErrorCode.HTTP_FRAME_UNEXPECTED
def encode_frame(frame_type: int, frame_data: bytes) -> bytes:
frame_length = len(frame_data)
buf = Buffer(capacity=frame_length + 16)
buf.push_uint_var(frame_type)
buf.push_uint_var(frame_length)
buf.push_bytes(frame_data)
return buf.data
def encode_settings(settings: Dict[int, int]) -> bytes:
buf = Buffer(capacity=1024)
for setting, value in settings.items():
buf.push_uint_var(setting)
buf.push_uint_var(value)
return buf.data
def parse_max_push_id(data: bytes) -> int:
buf = Buffer(data=data)
max_push_id = buf.pull_uint_var()
assert buf.eof()
return max_push_id
def parse_settings(data: bytes) -> Dict[int, int]:
buf = Buffer(data=data)
settings = []
while not buf.eof():
setting = buf.pull_uint_var()
value = buf.pull_uint_var()
settings.append((setting, value))
return dict(settings)
def qlog_encode_data_frame(byte_length: int, stream_id: int) -> Dict:
return {
"byte_length": str(byte_length),
"frame": {"frame_type": "data"},
"stream_id": str(stream_id),
}
def qlog_encode_headers(headers: Headers) -> List[Dict]:
return [
{"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers
]
def qlog_encode_headers_frame(
byte_length: int, headers: Headers, stream_id: int
) -> Dict:
return {
"byte_length": str(byte_length),
"frame": {"frame_type": "headers", "headers": qlog_encode_headers(headers)},
"stream_id": str(stream_id),
}
def qlog_encode_push_promise_frame(
byte_length: int, headers: Headers, push_id: int, stream_id: int
) -> Dict:
return {
"byte_length": str(byte_length),
"frame": {
"frame_type": "push_promise",
"headers": qlog_encode_headers(headers),
"push_id": str(push_id),
},
"stream_id": str(stream_id),
}
class H3Stream:
def __init__(self, stream_id: int) -> None:
self.blocked = False
self.blocked_frame_size: Optional[int] = None
self.buffer = b""
self.ended = False
self.frame_size: Optional[int] = None
self.frame_type: Optional[int] = None
self.headers_recv_state: HeadersState = HeadersState.INITIAL
self.headers_send_state: HeadersState = HeadersState.INITIAL
self.push_id: Optional[int] = None
self.stream_id = stream_id
self.stream_type: Optional[int] = None
class H3Connection:
"""
A low-level HTTP/3 connection object.
:param quic: A :class:`~aioquic.connection.QuicConnection` instance.
"""
def __init__(self, quic: QuicConnection):
self._max_table_capacity = 4096
self._blocked_streams = 16
self._is_client = quic.configuration.is_client
self._is_done = False
self._quic = quic
self._quic_logger: Optional[QuicLoggerTrace] = quic._quic_logger
self._decoder = pylsqpack.Decoder(
self._max_table_capacity, self._blocked_streams
)
self._decoder_bytes_received = 0
self._decoder_bytes_sent = 0
self._encoder = pylsqpack.Encoder()
self._encoder_bytes_received = 0
self._encoder_bytes_sent = 0
self._stream: Dict[int, H3Stream] = {}
self._max_push_id: Optional[int] = 8 if self._is_client else None
self._next_push_id: int = 0
self._local_control_stream_id: Optional[int] = None
self._local_decoder_stream_id: Optional[int] = None
self._local_encoder_stream_id: Optional[int] = None
self._peer_control_stream_id: Optional[int] = None
self._peer_decoder_stream_id: Optional[int] = None
self._peer_encoder_stream_id: Optional[int] = None
self._init_connection()
def handle_event(self, event: QuicEvent) -> List[H3Event]:
"""
Handle a QUIC event and return a list of HTTP events.
:param event: The QUIC event to handle.
"""
if isinstance(event, StreamDataReceived) and not self._is_done:
stream_id = event.stream_id
stream = self._get_or_create_stream(stream_id)
try:
if stream_id % 4 == 0:
return self._receive_request_or_push_data(
stream, event.data, event.end_stream
)
elif stream_is_unidirectional(stream_id):
return self._receive_stream_data_uni(
stream, event.data, event.end_stream
)
except ProtocolError as exc:
self._is_done = True
self._quic.close(
error_code=exc.error_code, reason_phrase=exc.reason_phrase
)
return []
def send_push_promise(self, stream_id: int, headers: Headers) -> int:
"""
Send a push promise related to the specified stream.
Returns the stream ID on which headers and data can be sent.
:param stream_id: The stream ID on which to send the data.
:param headers: The HTTP request headers for this push.
"""
assert not self._is_client, "Only servers may send a push promise."
if self._max_push_id is None or self._next_push_id >= self._max_push_id:
raise NoAvailablePushIDError
# send push promise
push_id = self._next_push_id
self._next_push_id += 1
self._quic.send_stream_data(
stream_id,
encode_frame(
FrameType.PUSH_PROMISE,
encode_uint_var(push_id) + self._encode_headers(stream_id, headers),
),
)
        # create push stream
push_stream_id = self._create_uni_stream(StreamType.PUSH)
self._quic.send_stream_data(push_stream_id, encode_uint_var(push_id))
return push_stream_id
def send_data(self, stream_id: int, data: bytes, end_stream: bool) -> None:
"""
Send data on the given stream.
        To retrieve the datagrams which need to be sent over the network, call the QUIC
connection's :meth:`~aioquic.connection.QuicConnection.datagrams_to_send`
method.
:param stream_id: The stream ID on which to send the data.
:param data: The data to send.
:param end_stream: Whether to end the stream.
"""
# check DATA frame is allowed
stream = self._get_or_create_stream(stream_id)
if stream.headers_send_state != HeadersState.AFTER_HEADERS:
raise FrameUnexpected("DATA frame is not allowed in this state")
# log frame
if self._quic_logger is not None:
self._quic_logger.log_event(
category="http",
event="frame_created",
data=qlog_encode_data_frame(byte_length=len(data), stream_id=stream_id),
)
self._quic.send_stream_data(
stream_id, encode_frame(FrameType.DATA, data), end_stream
)
def send_headers(
self, stream_id: int, headers: Headers, end_stream: bool = False
) -> None:
"""
Send headers on the given stream.
        To retrieve the datagrams which need to be sent over the network, call the QUIC
connection's :meth:`~aioquic.connection.QuicConnection.datagrams_to_send`
method.
:param stream_id: The stream ID on which to send the headers.
:param headers: The HTTP headers to send.
:param end_stream: Whether to end the stream.
"""
# check HEADERS frame is allowed
stream = self._get_or_create_stream(stream_id)
if stream.headers_send_state == HeadersState.AFTER_TRAILERS:
raise FrameUnexpected("HEADERS frame is not allowed in this state")
frame_data = self._encode_headers(stream_id, headers)
# log frame
if self._quic_logger is not None:
self._quic_logger.log_event(
category="http",
event="frame_created",
data=qlog_encode_headers_frame(
byte_length=len(frame_data), headers=headers, stream_id=stream_id
),
)
# update state and send headers
if stream.headers_send_state == HeadersState.INITIAL:
stream.headers_send_state = HeadersState.AFTER_HEADERS
else:
stream.headers_send_state = HeadersState.AFTER_TRAILERS
self._quic.send_stream_data(
stream_id, encode_frame(FrameType.HEADERS, frame_data), end_stream
)
def _create_uni_stream(self, stream_type: int) -> int:
"""
        Create a unidirectional stream of the given type.
"""
stream_id = self._quic.get_next_available_stream_id(is_unidirectional=True)
self._quic.send_stream_data(stream_id, encode_uint_var(stream_type))
return stream_id
def _decode_headers(self, stream_id: int, frame_data: Optional[bytes]) -> Headers:
"""
Decode a HEADERS block and send decoder updates on the decoder stream.
This is called with frame_data=None when a stream becomes unblocked.
"""
try:
if frame_data is None:
decoder, headers = self._decoder.resume_header(stream_id)
else:
decoder, headers = self._decoder.feed_header(stream_id, frame_data)
self._decoder_bytes_sent += len(decoder)
self._quic.send_stream_data(self._local_decoder_stream_id, decoder)
except pylsqpack.DecompressionFailed as exc:
raise QpackDecompressionFailed() from exc
return headers
def _encode_headers(self, stream_id: int, headers: Headers) -> bytes:
"""
Encode a HEADERS block and send encoder updates on the encoder stream.
"""
encoder, frame_data = self._encoder.encode(stream_id, headers)
self._encoder_bytes_sent += len(encoder)
self._quic.send_stream_data(self._local_encoder_stream_id, encoder)
return frame_data
def _get_or_create_stream(self, stream_id: int) -> H3Stream:
if stream_id not in self._stream:
self._stream[stream_id] = H3Stream(stream_id)
return self._stream[stream_id]
def _handle_control_frame(self, frame_type: int, frame_data: bytes) -> None:
"""
Handle a frame received on the peer's control stream.
"""
if frame_type == FrameType.SETTINGS:
settings = parse_settings(frame_data)
encoder = self._encoder.apply_settings(
max_table_capacity=settings.get(Setting.QPACK_MAX_TABLE_CAPACITY, 0),
blocked_streams=settings.get(Setting.QPACK_BLOCKED_STREAMS, 0),
)
self._quic.send_stream_data(self._local_encoder_stream_id, encoder)
elif frame_type == FrameType.MAX_PUSH_ID:
if self._is_client:
raise FrameUnexpected("Servers must not send MAX_PUSH_ID")
self._max_push_id = parse_max_push_id(frame_data)
elif frame_type in (
FrameType.DATA,
FrameType.HEADERS,
FrameType.PUSH_PROMISE,
FrameType.DUPLICATE_PUSH,
):
raise FrameUnexpected("Invalid frame type on control stream")
def _handle_request_or_push_frame(
self,
frame_type: int,
frame_data: Optional[bytes],
stream: H3Stream,
stream_ended: bool,
) -> List[H3Event]:
"""
Handle a frame received on a request or push stream.
"""
http_events: List[H3Event] = []
if frame_type == FrameType.DATA:
# check DATA frame is allowed
if stream.headers_recv_state != HeadersState.AFTER_HEADERS:
raise FrameUnexpected("DATA frame is not allowed in this state")
if stream_ended or frame_data:
http_events.append(
DataReceived(
data=frame_data,
push_id=stream.push_id,
stream_ended=stream_ended,
stream_id=stream.stream_id,
)
)
elif frame_type == FrameType.HEADERS:
# check HEADERS frame is allowed
if stream.headers_recv_state == HeadersState.AFTER_TRAILERS:
raise FrameUnexpected("HEADERS frame is not allowed in this state")
# try to decode HEADERS, may raise pylsqpack.StreamBlocked
headers = self._decode_headers(stream.stream_id, frame_data)
# log frame
if self._quic_logger is not None:
self._quic_logger.log_event(
category="http",
event="frame_parsed",
data=qlog_encode_headers_frame(
byte_length=stream.blocked_frame_size
if frame_data is None
else len(frame_data),
headers=headers,
stream_id=stream.stream_id,
),
)
# update state and emit headers
if stream.headers_recv_state == HeadersState.INITIAL:
stream.headers_recv_state = HeadersState.AFTER_HEADERS
else:
stream.headers_recv_state = HeadersState.AFTER_TRAILERS
http_events.append(
HeadersReceived(
headers=headers,
push_id=stream.push_id,
stream_id=stream.stream_id,
stream_ended=stream_ended,
)
)
elif stream.frame_type == FrameType.PUSH_PROMISE and stream.push_id is None:
if not self._is_client:
raise FrameUnexpected("Clients must not send PUSH_PROMISE")
frame_buf = Buffer(data=frame_data)
push_id = frame_buf.pull_uint_var()
headers = self._decode_headers(
stream.stream_id, frame_data[frame_buf.tell() :]
)
# log frame
if self._quic_logger is not None:
self._quic_logger.log_event(
category="http",
event="frame_parsed",
data=qlog_encode_push_promise_frame(
byte_length=len(frame_data),
headers=headers,
push_id=push_id,
stream_id=stream.stream_id,
),
)
# emit event
http_events.append(
PushPromiseReceived(
headers=headers, push_id=push_id, stream_id=stream.stream_id
)
)
elif frame_type in (
FrameType.PRIORITY,
FrameType.CANCEL_PUSH,
FrameType.SETTINGS,
FrameType.PUSH_PROMISE,
FrameType.GOAWAY,
FrameType.MAX_PUSH_ID,
FrameType.DUPLICATE_PUSH,
):
raise FrameUnexpected(
"Invalid frame type on request stream"
if stream.push_id is None
else "Invalid frame type on push stream"
)
return http_events
def _init_connection(self) -> None:
# send our settings
self._local_control_stream_id = self._create_uni_stream(StreamType.CONTROL)
self._quic.send_stream_data(
self._local_control_stream_id,
encode_frame(
FrameType.SETTINGS,
encode_settings(
{
Setting.QPACK_MAX_TABLE_CAPACITY: self._max_table_capacity,
Setting.QPACK_BLOCKED_STREAMS: self._blocked_streams,
}
),
),
)
if self._is_client and self._max_push_id is not None:
self._quic.send_stream_data(
self._local_control_stream_id,
encode_frame(FrameType.MAX_PUSH_ID, encode_uint_var(self._max_push_id)),
)
# create encoder and decoder streams
self._local_encoder_stream_id = self._create_uni_stream(
StreamType.QPACK_ENCODER
)
self._local_decoder_stream_id = self._create_uni_stream(
StreamType.QPACK_DECODER
)
def _receive_request_or_push_data(
self, stream: H3Stream, data: bytes, stream_ended: bool
) -> List[H3Event]:
"""
Handle data received on a request or push stream.
"""
http_events: List[H3Event] = []
stream.buffer += data
if stream_ended:
stream.ended = True
if stream.blocked:
return http_events
# shortcut for DATA frame fragments
if (
stream.frame_type == FrameType.DATA
and stream.frame_size is not None
and len(stream.buffer) < stream.frame_size
):
http_events.append(
DataReceived(
data=stream.buffer,
push_id=stream.push_id,
stream_id=stream.stream_id,
stream_ended=False,
)
)
stream.frame_size -= len(stream.buffer)
stream.buffer = b""
return http_events
# handle lone FIN
if stream_ended and not stream.buffer:
http_events.append(
DataReceived(
data=b"",
push_id=stream.push_id,
stream_id=stream.stream_id,
stream_ended=True,
)
)
return http_events
buf = Buffer(data=stream.buffer)
consumed = 0
while not buf.eof():
# fetch next frame header
if stream.frame_size is None:
try:
stream.frame_type = buf.pull_uint_var()
stream.frame_size = buf.pull_uint_var()
except BufferReadError:
break
consumed = buf.tell()
# log frame
if (
self._quic_logger is not None
and stream.frame_type == FrameType.DATA
):
self._quic_logger.log_event(
category="http",
event="frame_parsed",
data=qlog_encode_data_frame(
byte_length=stream.frame_size, stream_id=stream.stream_id
),
)
# check how much data is available
chunk_size = min(stream.frame_size, buf.capacity - consumed)
if stream.frame_type != FrameType.DATA and chunk_size < stream.frame_size:
break
# read available data
frame_data = buf.pull_bytes(chunk_size)
consumed = buf.tell()
# detect end of frame
stream.frame_size -= chunk_size
if not stream.frame_size:
stream.frame_size = None
try:
http_events.extend(
self._handle_request_or_push_frame(
frame_type=stream.frame_type,
frame_data=frame_data,
stream=stream,
stream_ended=stream.ended and buf.eof(),
)
)
except pylsqpack.StreamBlocked:
stream.blocked = True
stream.blocked_frame_size = len(frame_data)
break
# remove processed data from buffer
stream.buffer = stream.buffer[consumed:]
return http_events
def _receive_stream_data_uni(
self, stream: H3Stream, data: bytes, stream_ended: bool
) -> List[H3Event]:
http_events: List[H3Event] = []
stream.buffer += data
if stream_ended:
stream.ended = True
buf = Buffer(data=stream.buffer)
consumed = 0
unblocked_streams: Set[int] = set()
while stream.stream_type == StreamType.PUSH or not buf.eof():
# fetch stream type for unidirectional streams
if stream.stream_type is None:
try:
stream.stream_type = buf.pull_uint_var()
except BufferReadError:
break
consumed = buf.tell()
# check unicity
if stream.stream_type == StreamType.CONTROL:
if self._peer_control_stream_id is not None:
raise StreamCreationError("Only one control stream is allowed")
self._peer_control_stream_id = stream.stream_id
elif stream.stream_type == StreamType.QPACK_DECODER:
if self._peer_decoder_stream_id is not None:
raise StreamCreationError(
"Only one QPACK decoder stream is allowed"
)
self._peer_decoder_stream_id = stream.stream_id
elif stream.stream_type == StreamType.QPACK_ENCODER:
if self._peer_encoder_stream_id is not None:
raise StreamCreationError(
"Only one QPACK encoder stream is allowed"
)
self._peer_encoder_stream_id = stream.stream_id
if stream.stream_type == StreamType.CONTROL:
# fetch next frame
try:
frame_type = buf.pull_uint_var()
frame_length = buf.pull_uint_var()
frame_data = buf.pull_bytes(frame_length)
except BufferReadError:
break
consumed = buf.tell()
self._handle_control_frame(frame_type, frame_data)
elif stream.stream_type == StreamType.PUSH:
# fetch push id
if stream.push_id is None:
try:
stream.push_id = buf.pull_uint_var()
except BufferReadError:
break
consumed = buf.tell()
# remove processed data from buffer
stream.buffer = stream.buffer[consumed:]
return self._receive_request_or_push_data(stream, b"", stream_ended)
elif stream.stream_type == StreamType.QPACK_DECODER:
# feed unframed data to decoder
data = buf.pull_bytes(buf.capacity - buf.tell())
consumed = buf.tell()
try:
self._encoder.feed_decoder(data)
except pylsqpack.DecoderStreamError as exc:
raise QpackDecoderStreamError() from exc
self._decoder_bytes_received += len(data)
elif stream.stream_type == StreamType.QPACK_ENCODER:
# feed unframed data to encoder
data = buf.pull_bytes(buf.capacity - buf.tell())
consumed = buf.tell()
try:
unblocked_streams.update(self._decoder.feed_encoder(data))
except pylsqpack.EncoderStreamError as exc:
raise QpackEncoderStreamError() from exc
self._encoder_bytes_received += len(data)
else:
# unknown stream type, discard data
buf.seek(buf.capacity)
consumed = buf.tell()
# remove processed data from buffer
stream.buffer = stream.buffer[consumed:]
# process unblocked streams
for stream_id in unblocked_streams:
stream = self._stream[stream_id]
# resume headers
http_events.extend(
self._handle_request_or_push_frame(
frame_type=FrameType.HEADERS,
frame_data=None,
stream=stream,
stream_ended=stream.ended and not stream.buffer,
)
)
stream.blocked = False
stream.blocked_frame_size = None
# resume processing
if stream.buffer:
http_events.extend(
self._receive_request_or_push_data(stream, b"", stream.ended)
)
return http_events
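A sketch of how this module is typically wired into the asyncio layer (illustrative only; the class below is hypothetical and the import path for QuicConnectionProtocol is an assumption about the package layout): each QUIC event received by the protocol is fed to H3Connection.handle_event, which returns the HTTP-level events it could extract.

from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DataReceived, HeadersReceived
from aioquic.quic.events import QuicEvent


class H3Protocol(QuicConnectionProtocol):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._http = H3Connection(self._quic)

    def quic_event_received(self, event: QuicEvent) -> None:
        # Hand every QUIC event to the HTTP/3 layer and react to the
        # HTTP events it produces.
        for http_event in self._http.handle_event(event):
            if isinstance(http_event, HeadersReceived):
                print("headers on stream", http_event.stream_id, http_event.headers)
            elif isinstance(http_event, DataReceived):
                print("data on stream", http_event.stream_id, len(http_event.data))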


@ -0,0 +1,66 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
Headers = List[Tuple[bytes, bytes]]
class H3Event:
"""
Base class for HTTP/3 events.
"""
@dataclass
class DataReceived(H3Event):
"""
The DataReceived event is fired whenever data is received on a stream from
the remote peer.
"""
data: bytes
"The data which was received."
stream_id: int
"The ID of the stream the data was received for."
stream_ended: bool
"Whether the STREAM frame had the FIN bit set."
push_id: Optional[int] = None
"The Push ID or `None` if this is not a push."
@dataclass
class HeadersReceived(H3Event):
"""
The HeadersReceived event is fired whenever headers are received.
"""
headers: Headers
"The headers."
stream_id: int
"The ID of the stream the headers were received for."
stream_ended: bool
"Whether the STREAM frame had the FIN bit set."
push_id: Optional[int] = None
"The Push ID or `None` if this is not a push."
@dataclass
class PushPromiseReceived(H3Event):
"""
    The PushPromiseReceived event is fired whenever a push promise has been
received from the remote peer.
"""
headers: Headers
"The request headers."
push_id: int
"The Push ID of the push promise."
stream_id: int
"The Stream ID of the stream that the push is related to."


@ -0,0 +1,10 @@
class H3Error(Exception):
"""
Base class for HTTP/3 exceptions.
"""
class NoAvailablePushIDError(H3Error):
"""
There are no available push IDs left.
"""


@ -0,0 +1 @@
Marker


@ -0,0 +1,124 @@
from dataclasses import dataclass, field
from os import PathLike
from typing import Any, List, Optional, TextIO, Union
from ..tls import SessionTicket, load_pem_private_key, load_pem_x509_certificates
from .logger import QuicLogger
from .packet import QuicProtocolVersion
@dataclass
class QuicConfiguration:
"""
A QUIC configuration.
"""
alpn_protocols: Optional[List[str]] = None
"""
A list of supported ALPN protocols.
"""
connection_id_length: int = 8
"""
The length in bytes of local connection IDs.
"""
idle_timeout: float = 60.0
"""
The idle timeout in seconds.
The connection is terminated if nothing is received for the given duration.
"""
is_client: bool = True
"""
Whether this is the client side of the QUIC connection.
"""
max_data: int = 1048576
"""
Connection-wide flow control limit.
"""
max_stream_data: int = 1048576
"""
Per-stream flow control limit.
"""
quic_logger: Optional[QuicLogger] = None
"""
The :class:`~aioquic.quic.logger.QuicLogger` instance to log events to.
"""
secrets_log_file: TextIO = None
"""
A file-like object in which to log traffic secrets.
This is useful to analyze traffic captures with Wireshark.
"""
server_name: Optional[str] = None
"""
    The server name to send during the TLS handshake as the Server Name Indication.
.. note:: This is only used by clients.
"""
session_ticket: Optional[SessionTicket] = None
"""
The TLS session ticket which should be used for session resumption.
"""
cadata: Optional[bytes] = None
cafile: Optional[str] = None
capath: Optional[str] = None
certificate: Any = None
certificate_chain: List[Any] = field(default_factory=list)
max_datagram_frame_size: Optional[int] = None
private_key: Any = None
quantum_readiness_test: bool = False
supported_versions: List[int] = field(
default_factory=lambda: [
QuicProtocolVersion.DRAFT_27,
QuicProtocolVersion.DRAFT_26,
QuicProtocolVersion.DRAFT_25,
]
)
verify_mode: Optional[int] = None
def load_cert_chain(
self,
certfile: PathLike,
keyfile: Optional[PathLike] = None,
password: Optional[Union[bytes, str]] = None,
) -> None:
"""
Load a private key and the corresponding certificate.
"""
with open(certfile, "rb") as fp:
certificates = load_pem_x509_certificates(fp.read())
self.certificate = certificates[0]
self.certificate_chain = certificates[1:]
if keyfile is not None:
with open(keyfile, "rb") as fp:
self.private_key = load_pem_private_key(
fp.read(),
password=password.encode("utf8")
if isinstance(password, str)
else password,
)
def load_verify_locations(
self,
cafile: Optional[str] = None,
capath: Optional[str] = None,
cadata: Optional[bytes] = None,
) -> None:
"""
Load a set of "certification authority" (CA) certificates used to
validate other peers' certificates.
"""
self.cafile = cafile
self.capath = capath
self.cadata = cadata
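A brief configuration sketch (illustrative only; all file paths are placeholders):

from aioquic.quic.configuration import QuicConfiguration

# Server side: disable the client default and load a certificate chain.
configuration = QuicConfiguration(
    is_client=False,
    alpn_protocols=["h3-27"],
    idle_timeout=30.0,
)
configuration.load_cert_chain("ssl_cert.pem", keyfile="ssl_key.pem")

# Client side: validate the server against a private CA bundle.
client_configuration = QuicConfiguration(alpn_protocols=["h3-27"])
client_configuration.load_verify_locations(cafile="pycacert.pem")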

Diff for this file not shown because of its size.


@ -0,0 +1,196 @@
import binascii
from typing import Optional, Tuple
from .._crypto import AEAD, CryptoError, HeaderProtection
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
from .packet import decode_packet_number, is_long_header
CIPHER_SUITES = {
CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
}
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
INITIAL_SALT = binascii.unhexlify("c3eef712c72ebb5a11a7d2432bb46365bef9f502")
SAMPLE_SIZE = 16
class KeyUnavailableError(CryptoError):
pass
def derive_key_iv_hp(
cipher_suite: CipherSuite, secret: bytes
) -> Tuple[bytes, bytes, bytes]:
algorithm = cipher_suite_hash(cipher_suite)
if cipher_suite in [
CipherSuite.AES_256_GCM_SHA384,
CipherSuite.CHACHA20_POLY1305_SHA256,
]:
key_size = 32
else:
key_size = 16
return (
hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size),
hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12),
hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size),
)
class CryptoContext:
def __init__(self, key_phase: int = 0) -> None:
self.aead: Optional[AEAD] = None
self.cipher_suite: Optional[CipherSuite] = None
self.hp: Optional[HeaderProtection] = None
self.key_phase = key_phase
self.secret: Optional[bytes] = None
self.version: Optional[int] = None
def decrypt_packet(
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
) -> Tuple[bytes, bytes, int, bool]:
if self.aead is None:
raise KeyUnavailableError("Decryption key is not available")
# header protection
plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
first_byte = plain_header[0]
# packet number
pn_length = (first_byte & 0x03) + 1
packet_number = decode_packet_number(
packet_number, pn_length * 8, expected_packet_number
)
# detect key phase change
crypto = self
if not is_long_header(first_byte):
key_phase = (first_byte & 4) >> 2
if key_phase != self.key_phase:
crypto = next_key_phase(self)
# payload protection
payload = crypto.aead.decrypt(
packet[len(plain_header) :], plain_header, packet_number
)
return plain_header, payload, packet_number, crypto != self
def encrypt_packet(
self, plain_header: bytes, plain_payload: bytes, packet_number: int
) -> bytes:
assert self.is_valid(), "Encryption key is not available"
# payload protection
protected_payload = self.aead.encrypt(
plain_payload, plain_header, packet_number
)
# header protection
return self.hp.apply(plain_header, protected_payload)
def is_valid(self) -> bool:
return self.aead is not None
def setup(self, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]
key, iv, hp = derive_key_iv_hp(cipher_suite, secret)
self.aead = AEAD(aead_cipher_name, key, iv)
self.cipher_suite = cipher_suite
self.hp = HeaderProtection(hp_cipher_name, hp)
self.secret = secret
self.version = version
def teardown(self) -> None:
self.aead = None
self.cipher_suite = None
self.hp = None
self.secret = None
def apply_key_phase(self: CryptoContext, crypto: CryptoContext) -> None:
self.aead = crypto.aead
self.key_phase = crypto.key_phase
self.secret = crypto.secret
def next_key_phase(self: CryptoContext) -> CryptoContext:
algorithm = cipher_suite_hash(self.cipher_suite)
crypto = CryptoContext(key_phase=int(not self.key_phase))
crypto.setup(
cipher_suite=self.cipher_suite,
secret=hkdf_expand_label(
algorithm, self.secret, b"quic ku", b"", algorithm.digest_size
),
version=self.version,
)
return crypto
class CryptoPair:
def __init__(self) -> None:
self.aead_tag_size = 16
self.recv = CryptoContext()
self.send = CryptoContext()
self._update_key_requested = False
def decrypt_packet(
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
) -> Tuple[bytes, bytes, int]:
plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
packet, encrypted_offset, expected_packet_number
)
if update_key:
self._update_key()
return plain_header, payload, packet_number
def encrypt_packet(
self, plain_header: bytes, plain_payload: bytes, packet_number: int
) -> bytes:
if self._update_key_requested:
self._update_key()
return self.send.encrypt_packet(plain_header, plain_payload, packet_number)
def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
if is_client:
recv_label, send_label = b"server in", b"client in"
else:
recv_label, send_label = b"client in", b"server in"
algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
initial_secret = hkdf_extract(algorithm, INITIAL_SALT, cid)
self.recv.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=hkdf_expand_label(
algorithm, initial_secret, recv_label, b"", algorithm.digest_size
),
version=version,
)
self.send.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=hkdf_expand_label(
algorithm, initial_secret, send_label, b"", algorithm.digest_size
),
version=version,
)
def teardown(self) -> None:
self.recv.teardown()
self.send.teardown()
def update_key(self) -> None:
self._update_key_requested = True
@property
def key_phase(self) -> int:
if self._update_key_requested:
return int(not self.recv.key_phase)
else:
return self.recv.key_phase
def _update_key(self) -> None:
apply_key_phase(self.recv, next_key_phase(self.recv))
apply_key_phase(self.send, next_key_phase(self.send))
self._update_key_requested = False
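A small sketch of the CryptoPair lifecycle (illustrative only): derive the Initial keys from a connection ID, then signal a key update, which is applied the next time a packet is encrypted.

import os

from aioquic.quic.crypto import CryptoPair
from aioquic.quic.packet import QuicProtocolVersion

pair = CryptoPair()
pair.setup_initial(
    cid=os.urandom(8), is_client=True, version=QuicProtocolVersion.DRAFT_27
)
assert pair.key_phase == 0

# Request a key update; key_phase already reflects the pending phase,
# but the new keys are only installed by the next encrypt_packet() call.
pair.update_key()
assert pair.key_phase == 1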


@ -0,0 +1,112 @@
from dataclasses import dataclass
from typing import Optional
class QuicEvent:
"""
Base class for QUIC events.
"""
pass
@dataclass
class ConnectionIdIssued(QuicEvent):
connection_id: bytes
@dataclass
class ConnectionIdRetired(QuicEvent):
connection_id: bytes
@dataclass
class ConnectionTerminated(QuicEvent):
"""
The ConnectionTerminated event is fired when the QUIC connection is terminated.
"""
error_code: int
"The error code which was specified when closing the connection."
frame_type: Optional[int]
"The frame type which caused the connection to be closed, or `None`."
reason_phrase: str
"The human-readable reason for which the connection was closed."
@dataclass
class DatagramFrameReceived(QuicEvent):
"""
The DatagramFrameReceived event is fired when a DATAGRAM frame is received.
"""
data: bytes
"The data which was received."
@dataclass
class HandshakeCompleted(QuicEvent):
"""
The HandshakeCompleted event is fired when the TLS handshake completes.
"""
alpn_protocol: Optional[str]
"The protocol which was negotiated using ALPN, or `None`."
early_data_accepted: bool
"Whether early (0-RTT) data was accepted by the remote peer."
session_resumed: bool
"Whether a TLS session was resumed."
@dataclass
class PingAcknowledged(QuicEvent):
"""
The PingAcknowledged event is fired when a PING frame is acknowledged.
"""
uid: int
"The unique ID of the PING."
@dataclass
class ProtocolNegotiated(QuicEvent):
"""
The ProtocolNegotiated event is fired when ALPN negotiation completes.
"""
alpn_protocol: Optional[str]
"The protocol which was negotiated using ALPN, or `None`."
@dataclass
class StreamDataReceived(QuicEvent):
"""
The StreamDataReceived event is fired whenever data is received on a
stream.
"""
data: bytes
"The data which was received."
end_stream: bool
"Whether the STREAM frame had the FIN bit set."
stream_id: int
"The ID of the stream the data was received for."
@dataclass
class StreamReset(QuicEvent):
"""
The StreamReset event is fired when the remote peer resets a stream.
"""
error_code: int
"The error code that triggered the reset."
stream_id: int
"The ID of the stream that was reset."


@ -0,0 +1,266 @@
import binascii
import time
from collections import deque
from typing import Any, Deque, Dict, List, Optional, Tuple
from .packet import (
PACKET_TYPE_HANDSHAKE,
PACKET_TYPE_INITIAL,
PACKET_TYPE_MASK,
PACKET_TYPE_ONE_RTT,
PACKET_TYPE_RETRY,
PACKET_TYPE_ZERO_RTT,
QuicStreamFrame,
QuicTransportParameters,
)
from .rangeset import RangeSet
PACKET_TYPE_NAMES = {
PACKET_TYPE_INITIAL: "initial",
PACKET_TYPE_HANDSHAKE: "handshake",
PACKET_TYPE_ZERO_RTT: "0RTT",
PACKET_TYPE_ONE_RTT: "1RTT",
PACKET_TYPE_RETRY: "retry",
}
def hexdump(data: bytes) -> str:
return binascii.hexlify(data).decode("ascii")
class QuicLoggerTrace:
"""
A QUIC event trace.
Events are logged in the format defined by qlog draft-01.
See: https://quiclog.github.io/internet-drafts/draft-marx-qlog-event-definitions-quic-h3.html
"""
def __init__(self, *, is_client: bool, odcid: bytes) -> None:
self._odcid = odcid
self._events: Deque[Tuple[float, str, str, Dict[str, Any]]] = deque()
self._vantage_point = {
"name": "aioquic",
"type": "client" if is_client else "server",
}
def encode_ack_frame(self, ranges: RangeSet, delay: float) -> Dict:
return {
"ack_delay": str(self.encode_time(delay)),
"acked_ranges": [[str(x.start), str(x.stop - 1)] for x in ranges],
"frame_type": "ack",
}
def encode_connection_close_frame(
self, error_code: int, frame_type: Optional[int], reason_phrase: str
) -> Dict:
attrs = {
"error_code": error_code,
"error_space": "application" if frame_type is None else "transport",
"frame_type": "connection_close",
"raw_error_code": error_code,
"reason": reason_phrase,
}
if frame_type is not None:
attrs["trigger_frame_type"] = frame_type
return attrs
def encode_crypto_frame(self, frame: QuicStreamFrame) -> Dict:
return {
"frame_type": "crypto",
"length": len(frame.data),
"offset": str(frame.offset),
}
def encode_data_blocked_frame(self, limit: int) -> Dict:
return {"frame_type": "data_blocked", "limit": str(limit)}
def encode_datagram_frame(self, length: int) -> Dict:
return {"frame_type": "datagram", "length": length}
def encode_handshake_done_frame(self) -> Dict:
return {"frame_type": "handshake_done"}
def encode_max_data_frame(self, maximum: int) -> Dict:
return {"frame_type": "max_data", "maximum": str(maximum)}
def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> Dict:
return {
"frame_type": "max_stream_data",
"maximum": str(maximum),
"stream_id": str(stream_id),
}
def encode_max_streams_frame(self, is_unidirectional: bool, maximum: int) -> Dict:
return {
"frame_type": "max_streams",
"maximum": str(maximum),
"stream_type": "unidirectional" if is_unidirectional else "bidirectional",
}
def encode_new_connection_id_frame(
self,
connection_id: bytes,
retire_prior_to: int,
sequence_number: int,
stateless_reset_token: bytes,
) -> Dict:
return {
"connection_id": hexdump(connection_id),
"frame_type": "new_connection_id",
"length": len(connection_id),
"reset_token": hexdump(stateless_reset_token),
"retire_prior_to": str(retire_prior_to),
"sequence_number": str(sequence_number),
}
def encode_new_token_frame(self, token: bytes) -> Dict:
return {
"frame_type": "new_token",
"length": len(token),
"token": hexdump(token),
}
def encode_padding_frame(self) -> Dict:
return {"frame_type": "padding"}
def encode_path_challenge_frame(self, data: bytes) -> Dict:
return {"data": hexdump(data), "frame_type": "path_challenge"}
def encode_path_response_frame(self, data: bytes) -> Dict:
return {"data": hexdump(data), "frame_type": "path_response"}
def encode_ping_frame(self) -> Dict:
return {"frame_type": "ping"}
def encode_reset_stream_frame(
self, error_code: int, final_size: int, stream_id: int
) -> Dict:
return {
"error_code": error_code,
"final_size": str(final_size),
"frame_type": "reset_stream",
"stream_id": str(stream_id),
}
def encode_retire_connection_id_frame(self, sequence_number: int) -> Dict:
return {
"frame_type": "retire_connection_id",
"sequence_number": str(sequence_number),
}
def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> Dict:
return {
"frame_type": "stream_data_blocked",
"limit": str(limit),
"stream_id": str(stream_id),
}
def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> Dict:
return {
"frame_type": "stop_sending",
"error_code": error_code,
"stream_id": str(stream_id),
}
def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> Dict:
return {
"fin": frame.fin,
"frame_type": "stream",
"length": len(frame.data),
"offset": str(frame.offset),
"stream_id": str(stream_id),
}
def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> Dict:
return {
"frame_type": "streams_blocked",
"limit": str(limit),
"stream_type": "unidirectional" if is_unidirectional else "bidirectional",
}
def encode_time(self, seconds: float) -> int:
"""
Convert a time to integer microseconds.
"""
return int(seconds * 1000000)
def encode_transport_parameters(
self, owner: str, parameters: QuicTransportParameters
) -> Dict[str, Any]:
data: Dict[str, Any] = {"owner": owner}
for param_name, param_value in parameters.__dict__.items():
if isinstance(param_value, bool):
data[param_name] = param_value
elif isinstance(param_value, bytes):
data[param_name] = hexdump(param_value)
elif isinstance(param_value, int):
data[param_name] = param_value
return data
def log_event(self, *, category: str, event: str, data: Dict) -> None:
self._events.append((time.time(), category, event, data))
def packet_type(self, packet_type: int) -> str:
return PACKET_TYPE_NAMES.get(packet_type & PACKET_TYPE_MASK, "1RTT")
def to_dict(self) -> Dict[str, Any]:
"""
Return the trace as a dictionary which can be written as JSON.
"""
if self._events:
reference_time = self._events[0][0]
else:
reference_time = 0.0
return {
"configuration": {"time_units": "us"},
"common_fields": {
"ODCID": hexdump(self._odcid),
"reference_time": str(self.encode_time(reference_time)),
},
"event_fields": ["relative_time", "category", "event_type", "data"],
"events": list(
map(
lambda event: (
str(self.encode_time(event[0] - reference_time)),
event[1],
event[2],
event[3],
),
self._events,
)
),
"vantage_point": self._vantage_point,
}
class QuicLogger:
"""
A QUIC event logger.
Serves as a container for traces in the format defined by qlog draft-01.
See: https://quiclog.github.io/internet-drafts/draft-marx-qlog-main-schema.html
"""
def __init__(self) -> None:
self._traces: List[QuicLoggerTrace] = []
def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace:
trace = QuicLoggerTrace(is_client=is_client, odcid=odcid)
self._traces.append(trace)
return trace
def end_trace(self, trace: QuicLoggerTrace) -> None:
assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger"
def to_dict(self) -> Dict[str, Any]:
"""
Return the traces as a dictionary which can be written as JSON.
"""
return {
"qlog_version": "draft-01",
"traces": [trace.to_dict() for trace in self._traces],
}
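A usage sketch for the qlog container (illustrative only; the output path is a placeholder):

import json

from aioquic.quic.logger import QuicLogger

quic_logger = QuicLogger()
trace = quic_logger.start_trace(is_client=True, odcid=bytes(8))
trace.log_event(
    category="transport",
    event="datagram_received",
    data={"byte_length": 1200},
)
quic_logger.end_trace(trace)

with open("trace.qlog", "w") as fp:
    json.dump(quic_logger.to_dict(), fp, indent=2)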


@ -0,0 +1,508 @@
import binascii
import ipaddress
import os
from dataclasses import dataclass
from enum import IntEnum
from typing import List, Optional, Tuple
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ..buffer import Buffer
from ..tls import pull_block, push_block
from .rangeset import RangeSet
PACKET_LONG_HEADER = 0x80
PACKET_FIXED_BIT = 0x40
PACKET_SPIN_BIT = 0x20
PACKET_TYPE_INITIAL = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x00
PACKET_TYPE_ZERO_RTT = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x10
PACKET_TYPE_HANDSHAKE = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x20
PACKET_TYPE_RETRY = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x30
PACKET_TYPE_ONE_RTT = PACKET_FIXED_BIT
PACKET_TYPE_MASK = 0xF0
CONNECTION_ID_MAX_SIZE = 20
PACKET_NUMBER_MAX_SIZE = 4
RETRY_AEAD_KEY = binascii.unhexlify("4d32ecdb2a2133c841e4043df27d4430")
RETRY_AEAD_NONCE = binascii.unhexlify("4d1611d05513a552c587d575")
RETRY_INTEGRITY_TAG_SIZE = 16
class QuicErrorCode(IntEnum):
NO_ERROR = 0x0
INTERNAL_ERROR = 0x1
SERVER_BUSY = 0x2
FLOW_CONTROL_ERROR = 0x3
STREAM_LIMIT_ERROR = 0x4
STREAM_STATE_ERROR = 0x5
FINAL_SIZE_ERROR = 0x6
FRAME_ENCODING_ERROR = 0x7
TRANSPORT_PARAMETER_ERROR = 0x8
CONNECTION_ID_LIMIT_ERROR = 0x9
PROTOCOL_VIOLATION = 0xA
INVALID_TOKEN = 0xB
CRYPTO_BUFFER_EXCEEDED = 0xD
CRYPTO_ERROR = 0x100
class QuicProtocolVersion(IntEnum):
NEGOTIATION = 0
DRAFT_25 = 0xFF000019
DRAFT_26 = 0xFF00001A
DRAFT_27 = 0xFF00001B
@dataclass
class QuicHeader:
is_long_header: bool
version: Optional[int]
packet_type: int
destination_cid: bytes
source_cid: bytes
token: bytes = b""
integrity_tag: bytes = b""
rest_length: int = 0
def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
"""
Recover a packet number from a truncated packet number.
See: Appendix A - Sample Packet Number Decoding Algorithm
"""
window = 1 << num_bits
half_window = window // 2
candidate = (expected & ~(window - 1)) | truncated
if candidate <= expected - half_window and candidate < (1 << 62) - window:
return candidate + window
elif candidate > expected + half_window and candidate >= window:
return candidate - window
else:
return candidate
def get_retry_integrity_tag(
packet_without_tag: bytes, original_destination_cid: bytes
) -> bytes:
"""
Calculate the integrity tag for a RETRY packet.
"""
# build Retry pseudo packet
buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
buf.push_uint8(len(original_destination_cid))
buf.push_bytes(original_destination_cid)
buf.push_bytes(packet_without_tag)
assert buf.eof()
# run AES-128-GCM
aead = AESGCM(RETRY_AEAD_KEY)
integrity_tag = aead.encrypt(RETRY_AEAD_NONCE, b"", buf.data)
assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
return integrity_tag
def get_spin_bit(first_byte: int) -> bool:
return bool(first_byte & PACKET_SPIN_BIT)
def is_long_header(first_byte: int) -> bool:
return bool(first_byte & PACKET_LONG_HEADER)
def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader:
first_byte = buf.pull_uint8()
integrity_tag = b""
token = b""
if is_long_header(first_byte):
# long header packet
version = buf.pull_uint32()
destination_cid_length = buf.pull_uint8()
if destination_cid_length > CONNECTION_ID_MAX_SIZE:
raise ValueError(
"Destination CID is too long (%d bytes)" % destination_cid_length
)
destination_cid = buf.pull_bytes(destination_cid_length)
source_cid_length = buf.pull_uint8()
if source_cid_length > CONNECTION_ID_MAX_SIZE:
raise ValueError("Source CID is too long (%d bytes)" % source_cid_length)
source_cid = buf.pull_bytes(source_cid_length)
if version == QuicProtocolVersion.NEGOTIATION:
# version negotiation
packet_type = None
rest_length = buf.capacity - buf.tell()
else:
if not (first_byte & PACKET_FIXED_BIT):
raise ValueError("Packet fixed bit is zero")
packet_type = first_byte & PACKET_TYPE_MASK
if packet_type == PACKET_TYPE_INITIAL:
token_length = buf.pull_uint_var()
token = buf.pull_bytes(token_length)
rest_length = buf.pull_uint_var()
elif packet_type == PACKET_TYPE_RETRY:
token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
token = buf.pull_bytes(token_length)
integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
rest_length = 0
else:
rest_length = buf.pull_uint_var()
return QuicHeader(
is_long_header=True,
version=version,
packet_type=packet_type,
destination_cid=destination_cid,
source_cid=source_cid,
token=token,
integrity_tag=integrity_tag,
rest_length=rest_length,
)
else:
# short header packet
if not (first_byte & PACKET_FIXED_BIT):
raise ValueError("Packet fixed bit is zero")
packet_type = first_byte & PACKET_TYPE_MASK
destination_cid = buf.pull_bytes(host_cid_length)
return QuicHeader(
is_long_header=False,
version=None,
packet_type=packet_type,
destination_cid=destination_cid,
source_cid=b"",
token=b"",
rest_length=buf.capacity - buf.tell(),
)
def encode_quic_retry(
version: int,
source_cid: bytes,
destination_cid: bytes,
original_destination_cid: bytes,
retry_token: bytes,
) -> bytes:
buf = Buffer(
capacity=7
+ len(destination_cid)
+ len(source_cid)
+ len(retry_token)
+ RETRY_INTEGRITY_TAG_SIZE
)
buf.push_uint8(PACKET_TYPE_RETRY)
buf.push_uint32(version)
buf.push_uint8(len(destination_cid))
buf.push_bytes(destination_cid)
buf.push_uint8(len(source_cid))
buf.push_bytes(source_cid)
buf.push_bytes(retry_token)
buf.push_bytes(get_retry_integrity_tag(buf.data, original_destination_cid))
assert buf.eof()
return buf.data
def encode_quic_version_negotiation(
source_cid: bytes, destination_cid: bytes, supported_versions: List[int]
) -> bytes:
buf = Buffer(
capacity=7
+ len(destination_cid)
+ len(source_cid)
+ 4 * len(supported_versions)
)
buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
buf.push_uint8(len(destination_cid))
buf.push_bytes(destination_cid)
buf.push_uint8(len(source_cid))
buf.push_bytes(source_cid)
for version in supported_versions:
buf.push_uint32(version)
return buf.data
# TLS EXTENSION
@dataclass
class QuicPreferredAddress:
ipv4_address: Optional[Tuple[str, int]]
ipv6_address: Optional[Tuple[str, int]]
connection_id: bytes
stateless_reset_token: bytes
@dataclass
class QuicTransportParameters:
original_connection_id: Optional[bytes] = None
idle_timeout: Optional[int] = None
stateless_reset_token: Optional[bytes] = None
max_packet_size: Optional[int] = None
initial_max_data: Optional[int] = None
initial_max_stream_data_bidi_local: Optional[int] = None
initial_max_stream_data_bidi_remote: Optional[int] = None
initial_max_stream_data_uni: Optional[int] = None
initial_max_streams_bidi: Optional[int] = None
initial_max_streams_uni: Optional[int] = None
ack_delay_exponent: Optional[int] = None
max_ack_delay: Optional[int] = None
disable_active_migration: Optional[bool] = False
preferred_address: Optional[QuicPreferredAddress] = None
active_connection_id_limit: Optional[int] = None
max_datagram_frame_size: Optional[int] = None
quantum_readiness: Optional[bytes] = None
PARAMS = {
0: ("original_connection_id", bytes),
1: ("idle_timeout", int),
2: ("stateless_reset_token", bytes),
3: ("max_packet_size", int),
4: ("initial_max_data", int),
5: ("initial_max_stream_data_bidi_local", int),
6: ("initial_max_stream_data_bidi_remote", int),
7: ("initial_max_stream_data_uni", int),
8: ("initial_max_streams_bidi", int),
9: ("initial_max_streams_uni", int),
10: ("ack_delay_exponent", int),
11: ("max_ack_delay", int),
12: ("disable_active_migration", bool),
13: ("preferred_address", QuicPreferredAddress),
14: ("active_connection_id_limit", int),
32: ("max_datagram_frame_size", int),
3127: ("quantum_readiness", bytes),
}
def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
ipv4_address = None
ipv4_host = buf.pull_bytes(4)
ipv4_port = buf.pull_uint16()
if ipv4_host != bytes(4):
ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)
ipv6_address = None
ipv6_host = buf.pull_bytes(16)
ipv6_port = buf.pull_uint16()
if ipv6_host != bytes(16):
ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)
connection_id_length = buf.pull_uint8()
connection_id = buf.pull_bytes(connection_id_length)
stateless_reset_token = buf.pull_bytes(16)
return QuicPreferredAddress(
ipv4_address=ipv4_address,
ipv6_address=ipv6_address,
connection_id=connection_id,
stateless_reset_token=stateless_reset_token,
)
def push_quic_preferred_address(
buf: Buffer, preferred_address: QuicPreferredAddress
) -> None:
if preferred_address.ipv4_address is not None:
buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
buf.push_uint16(preferred_address.ipv4_address[1])
else:
buf.push_bytes(bytes(6))
if preferred_address.ipv6_address is not None:
buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
buf.push_uint16(preferred_address.ipv6_address[1])
else:
buf.push_bytes(bytes(18))
buf.push_uint8(len(preferred_address.connection_id))
buf.push_bytes(preferred_address.connection_id)
buf.push_bytes(preferred_address.stateless_reset_token)
def pull_quic_transport_parameters(
buf: Buffer, protocol_version: int
) -> QuicTransportParameters:
params = QuicTransportParameters()
if protocol_version < QuicProtocolVersion.DRAFT_27:
with pull_block(buf, 2) as length:
end = buf.tell() + length
while buf.tell() < end:
param_id = buf.pull_uint16()
param_len = buf.pull_uint16()
param_start = buf.tell()
if param_id in PARAMS:
# parse known parameter
param_name, param_type = PARAMS[param_id]
if param_type == int:
setattr(params, param_name, buf.pull_uint_var())
elif param_type == bytes:
setattr(params, param_name, buf.pull_bytes(param_len))
elif param_type == QuicPreferredAddress:
setattr(params, param_name, pull_quic_preferred_address(buf))
else:
setattr(params, param_name, True)
else:
# skip unknown parameter
buf.pull_bytes(param_len)
assert buf.tell() == param_start + param_len
else:
while not buf.eof():
param_id = buf.pull_uint_var()
param_len = buf.pull_uint_var()
param_start = buf.tell()
if param_id in PARAMS:
# parse known parameter
param_name, param_type = PARAMS[param_id]
if param_type == int:
setattr(params, param_name, buf.pull_uint_var())
elif param_type == bytes:
setattr(params, param_name, buf.pull_bytes(param_len))
elif param_type == QuicPreferredAddress:
setattr(params, param_name, pull_quic_preferred_address(buf))
else:
setattr(params, param_name, True)
else:
# skip unknown parameter
buf.pull_bytes(param_len)
assert buf.tell() == param_start + param_len
return params
def push_quic_transport_parameters(
buf: Buffer, params: QuicTransportParameters, protocol_version: int
) -> None:
if protocol_version < QuicProtocolVersion.DRAFT_27:
with push_block(buf, 2):
for param_id, (param_name, param_type) in PARAMS.items():
param_value = getattr(params, param_name)
if param_value is not None and param_value is not False:
buf.push_uint16(param_id)
with push_block(buf, 2):
if param_type == int:
buf.push_uint_var(param_value)
elif param_type == bytes:
buf.push_bytes(param_value)
elif param_type == QuicPreferredAddress:
push_quic_preferred_address(buf, param_value)
else:
for param_id, (param_name, param_type) in PARAMS.items():
param_value = getattr(params, param_name)
if param_value is not None and param_value is not False:
param_buf = Buffer(capacity=65536)
if param_type == int:
param_buf.push_uint_var(param_value)
elif param_type == bytes:
param_buf.push_bytes(param_value)
elif param_type == QuicPreferredAddress:
push_quic_preferred_address(param_buf, param_value)
buf.push_uint_var(param_id)
buf.push_uint_var(param_buf.tell())
buf.push_bytes(param_buf.data)
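# A minimal sketch of how the two helpers above fit together: encode a couple
# of transport parameters and decode them back. It uses only names defined or
# imported in this module; the field values are arbitrary examples.
params = QuicTransportParameters(idle_timeout=30000, max_packet_size=1350)
buf = Buffer(capacity=512)
push_quic_transport_parameters(buf, params, protocol_version=QuicProtocolVersion.DRAFT_27)
decoded = pull_quic_transport_parameters(
    Buffer(data=buf.data), protocol_version=QuicProtocolVersion.DRAFT_27
)
assert decoded.idle_timeout == 30000 and decoded.max_packet_size == 1350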
# FRAMES
class QuicFrameType(IntEnum):
PADDING = 0x00
PING = 0x01
ACK = 0x02
ACK_ECN = 0x03
RESET_STREAM = 0x04
STOP_SENDING = 0x05
CRYPTO = 0x06
NEW_TOKEN = 0x07
STREAM_BASE = 0x08
MAX_DATA = 0x10
MAX_STREAM_DATA = 0x11
MAX_STREAMS_BIDI = 0x12
MAX_STREAMS_UNI = 0x13
DATA_BLOCKED = 0x14
STREAM_DATA_BLOCKED = 0x15
STREAMS_BLOCKED_BIDI = 0x16
STREAMS_BLOCKED_UNI = 0x17
NEW_CONNECTION_ID = 0x18
RETIRE_CONNECTION_ID = 0x19
PATH_CHALLENGE = 0x1A
PATH_RESPONSE = 0x1B
TRANSPORT_CLOSE = 0x1C
APPLICATION_CLOSE = 0x1D
HANDSHAKE_DONE = 0x1E
DATAGRAM = 0x30
DATAGRAM_WITH_LENGTH = 0x31
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
[
QuicFrameType.ACK,
QuicFrameType.ACK_ECN,
QuicFrameType.PADDING,
QuicFrameType.TRANSPORT_CLOSE,
QuicFrameType.APPLICATION_CLOSE,
]
)
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
[
QuicFrameType.ACK,
QuicFrameType.ACK_ECN,
QuicFrameType.TRANSPORT_CLOSE,
QuicFrameType.APPLICATION_CLOSE,
]
)
PROBING_FRAME_TYPES = frozenset(
[
QuicFrameType.PATH_CHALLENGE,
QuicFrameType.PATH_RESPONSE,
QuicFrameType.PADDING,
QuicFrameType.NEW_CONNECTION_ID,
]
)
@dataclass
class QuicStreamFrame:
data: bytes = b""
fin: bool = False
offset: int = 0
def pull_ack_frame(buf: Buffer) -> Tuple[RangeSet, int]:
rangeset = RangeSet()
end = buf.pull_uint_var() # largest acknowledged
delay = buf.pull_uint_var()
ack_range_count = buf.pull_uint_var()
ack_count = buf.pull_uint_var() # first ack range
rangeset.add(end - ack_count, end + 1)
end -= ack_count
for _ in range(ack_range_count):
end -= buf.pull_uint_var() + 2
ack_count = buf.pull_uint_var()
rangeset.add(end - ack_count, end + 1)
end -= ack_count
return rangeset, delay
def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
ranges = len(rangeset)
index = ranges - 1
r = rangeset[index]
buf.push_uint_var(r.stop - 1)
buf.push_uint_var(delay)
buf.push_uint_var(index)
buf.push_uint_var(r.stop - 1 - r.start)
start = r.start
while index > 0:
index -= 1
r = rangeset[index]
buf.push_uint_var(start - r.stop - 1)
buf.push_uint_var(r.stop - r.start - 1)
start = r.start
return ranges
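# A small worked example of the ACK-frame helpers above: two received ranges
# are encoded into a frame body and decoded back. It relies only on the Buffer
# and RangeSet names this module already uses.
acked = RangeSet()
acked.add(0, 3)  # packets 0-2
acked.add(5, 7)  # packets 5-6
ack_buf = Buffer(capacity=64)
push_ack_frame(ack_buf, acked, 0)
decoded_ranges, decoded_delay = pull_ack_frame(Buffer(data=ack_buf.data))
assert list(decoded_ranges) == [range(0, 3), range(5, 7)] and decoded_delay == 0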


@ -0,0 +1,366 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from ..buffer import Buffer, size_uint_var
from ..tls import Epoch
from .crypto import CryptoPair
from .logger import QuicLoggerTrace
from .packet import (
NON_ACK_ELICITING_FRAME_TYPES,
NON_IN_FLIGHT_FRAME_TYPES,
PACKET_NUMBER_MAX_SIZE,
PACKET_TYPE_HANDSHAKE,
PACKET_TYPE_INITIAL,
PACKET_TYPE_MASK,
QuicFrameType,
is_long_header,
)
PACKET_MAX_SIZE = 1280
PACKET_LENGTH_SEND_SIZE = 2
PACKET_NUMBER_SEND_SIZE = 2
QuicDeliveryHandler = Callable[..., None]
class QuicDeliveryState(Enum):
ACKED = 0
LOST = 1
EXPIRED = 2
@dataclass
class QuicSentPacket:
epoch: Epoch
in_flight: bool
is_ack_eliciting: bool
is_crypto_packet: bool
packet_number: int
packet_type: int
sent_time: Optional[float] = None
sent_bytes: int = 0
delivery_handlers: List[Tuple[QuicDeliveryHandler, Any]] = field(
default_factory=list
)
quic_logger_frames: List[Dict] = field(default_factory=list)
class QuicPacketBuilderStop(Exception):
pass
class QuicPacketBuilder:
"""
Helper for building QUIC packets.
"""
def __init__(
self,
*,
host_cid: bytes,
peer_cid: bytes,
version: int,
is_client: bool,
packet_number: int = 0,
peer_token: bytes = b"",
quic_logger: Optional[QuicLoggerTrace] = None,
spin_bit: bool = False,
):
self.max_flight_bytes: Optional[int] = None
self.max_total_bytes: Optional[int] = None
self.quic_logger_frames: Optional[List[Dict]] = None
self._host_cid = host_cid
self._is_client = is_client
self._peer_cid = peer_cid
self._peer_token = peer_token
self._quic_logger = quic_logger
self._spin_bit = spin_bit
self._version = version
# assembled datagrams and packets
self._datagrams: List[bytes] = []
self._datagram_flight_bytes = 0
self._datagram_init = True
self._packets: List[QuicSentPacket] = []
self._flight_bytes = 0
self._total_bytes = 0
# current packet
self._header_size = 0
self._packet: Optional[QuicSentPacket] = None
self._packet_crypto: Optional[CryptoPair] = None
self._packet_long_header = False
self._packet_number = packet_number
self._packet_start = 0
self._packet_type = 0
self._buffer = Buffer(PACKET_MAX_SIZE)
self._buffer_capacity = PACKET_MAX_SIZE
self._flight_capacity = PACKET_MAX_SIZE
@property
def packet_is_empty(self) -> bool:
"""
Returns `True` if the current packet is empty.
"""
assert self._packet is not None
packet_size = self._buffer.tell() - self._packet_start
return packet_size <= self._header_size
@property
def packet_number(self) -> int:
"""
Returns the packet number for the next packet.
"""
return self._packet_number
@property
def remaining_buffer_space(self) -> int:
"""
Returns the remaining number of bytes which can be used in
the current packet.
"""
return (
self._buffer_capacity
- self._buffer.tell()
- self._packet_crypto.aead_tag_size
)
@property
def remaining_flight_space(self) -> int:
"""
Returns the remaining number of bytes which can be used in
the current packet without exceeding the in-flight byte limit.
"""
return (
self._flight_capacity
- self._buffer.tell()
- self._packet_crypto.aead_tag_size
)
def flush(self) -> Tuple[List[bytes], List[QuicSentPacket]]:
"""
Returns the assembled datagrams.
"""
if self._packet is not None:
self._end_packet()
self._flush_current_datagram()
datagrams = self._datagrams
packets = self._packets
self._datagrams = []
self._packets = []
return datagrams, packets
def start_frame(
self,
frame_type: int,
capacity: int = 1,
handler: Optional[QuicDeliveryHandler] = None,
handler_args: Sequence[Any] = [],
) -> Buffer:
"""
Starts a new frame.
"""
if self.remaining_buffer_space < capacity or (
frame_type not in NON_IN_FLIGHT_FRAME_TYPES
and self.remaining_flight_space < capacity
):
raise QuicPacketBuilderStop
self._buffer.push_uint_var(frame_type)
if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
self._packet.is_ack_eliciting = True
if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
self._packet.in_flight = True
if frame_type == QuicFrameType.CRYPTO:
self._packet.is_crypto_packet = True
if handler is not None:
self._packet.delivery_handlers.append((handler, handler_args))
return self._buffer
def start_packet(self, packet_type: int, crypto: CryptoPair) -> None:
"""
Starts a new packet.
"""
buf = self._buffer
# finish previous datagram
if self._packet is not None:
self._end_packet()
# if there is too little space remaining, start a new datagram
# FIXME: the limit is arbitrary!
packet_start = buf.tell()
if self._buffer_capacity - packet_start < 128:
self._flush_current_datagram()
packet_start = 0
# initialize datagram if needed
if self._datagram_init:
if self.max_total_bytes is not None:
remaining_total_bytes = self.max_total_bytes - self._total_bytes
if remaining_total_bytes < self._buffer_capacity:
self._buffer_capacity = remaining_total_bytes
self._flight_capacity = self._buffer_capacity
if self.max_flight_bytes is not None:
remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
if remaining_flight_bytes < self._flight_capacity:
self._flight_capacity = remaining_flight_bytes
self._datagram_flight_bytes = 0
self._datagram_init = False
# calculate header size
packet_long_header = is_long_header(packet_type)
if packet_long_header:
header_size = 11 + len(self._peer_cid) + len(self._host_cid)
if (packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL:
token_length = len(self._peer_token)
header_size += size_uint_var(token_length) + token_length
else:
header_size = 3 + len(self._peer_cid)
# check we have enough space
if packet_start + header_size >= self._buffer_capacity:
raise QuicPacketBuilderStop
# determine ack epoch
if packet_type == PACKET_TYPE_INITIAL:
epoch = Epoch.INITIAL
elif packet_type == PACKET_TYPE_HANDSHAKE:
epoch = Epoch.HANDSHAKE
else:
epoch = Epoch.ONE_RTT
self._header_size = header_size
self._packet = QuicSentPacket(
epoch=epoch,
in_flight=False,
is_ack_eliciting=False,
is_crypto_packet=False,
packet_number=self._packet_number,
packet_type=packet_type,
)
self._packet_crypto = crypto
self._packet_long_header = packet_long_header
self._packet_start = packet_start
self._packet_type = packet_type
self.quic_logger_frames = self._packet.quic_logger_frames
buf.seek(self._packet_start + self._header_size)
def _end_packet(self) -> None:
"""
Ends the current packet.
"""
buf = self._buffer
packet_size = buf.tell() - self._packet_start
if packet_size > self._header_size:
# pad initial datagram
if (
self._is_client
and self._packet_type == PACKET_TYPE_INITIAL
and self._packet.is_crypto_packet
):
if self.remaining_flight_space:
buf.push_bytes(bytes(self.remaining_flight_space))
packet_size = buf.tell() - self._packet_start
self._packet.in_flight = True
# log frame
if self._quic_logger is not None:
self._packet.quic_logger_frames.append(
self._quic_logger.encode_padding_frame()
)
# write header
if self._packet_long_header:
length = (
packet_size
- self._header_size
+ PACKET_NUMBER_SEND_SIZE
+ self._packet_crypto.aead_tag_size
)
buf.seek(self._packet_start)
buf.push_uint8(self._packet_type | (PACKET_NUMBER_SEND_SIZE - 1))
buf.push_uint32(self._version)
buf.push_uint8(len(self._peer_cid))
buf.push_bytes(self._peer_cid)
buf.push_uint8(len(self._host_cid))
buf.push_bytes(self._host_cid)
if (self._packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL:
buf.push_uint_var(len(self._peer_token))
buf.push_bytes(self._peer_token)
buf.push_uint16(length | 0x4000)
buf.push_uint16(self._packet_number & 0xFFFF)
else:
buf.seek(self._packet_start)
buf.push_uint8(
self._packet_type
| (self._spin_bit << 5)
| (self._packet_crypto.key_phase << 2)
| (PACKET_NUMBER_SEND_SIZE - 1)
)
buf.push_bytes(self._peer_cid)
buf.push_uint16(self._packet_number & 0xFFFF)
# check whether we need padding
padding_size = (
PACKET_NUMBER_MAX_SIZE
- PACKET_NUMBER_SEND_SIZE
+ self._header_size
- packet_size
)
if padding_size > 0:
buf.seek(self._packet_start + packet_size)
buf.push_bytes(bytes(padding_size))
packet_size += padding_size
self._packet.in_flight = True
# log frame
if self._quic_logger is not None:
self._packet.quic_logger_frames.append(
self._quic_logger.encode_padding_frame()
)
# encrypt in place
plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
buf.seek(self._packet_start)
buf.push_bytes(
self._packet_crypto.encrypt_packet(
plain[0 : self._header_size],
plain[self._header_size : packet_size],
self._packet_number,
)
)
self._packet.sent_bytes = buf.tell() - self._packet_start
self._packets.append(self._packet)
if self._packet.in_flight:
self._datagram_flight_bytes += self._packet.sent_bytes
# short header packets cannot be coalesced, we need a new datagram
if not self._packet_long_header:
self._flush_current_datagram()
self._packet_number += 1
else:
# "cancel" the packet
buf.seek(self._packet_start)
self._packet = None
self.quic_logger_frames = None
def _flush_current_datagram(self) -> None:
datagram_bytes = self._buffer.tell()
if datagram_bytes:
self._datagrams.append(self._buffer.data)
self._flight_bytes += self._datagram_flight_bytes
self._total_bytes += datagram_bytes
self._datagram_init = True
self._buffer.seek(0)
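# Back-of-the-envelope check of the header sizes start_packet() computes
# above, assuming 8-byte connection IDs and an empty Initial token: the long
# header is 11 + 8 + 8 + 1 (token length varint) = 28 bytes and the short
# header is 3 + 8 = 11 bytes. Illustrative arithmetic only.
peer_cid = host_cid = b"\x00" * 8
long_header_size = 11 + len(peer_cid) + len(host_cid) + 1
short_header_size = 3 + len(peer_cid)
assert (long_header_size, short_header_size) == (28, 11)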


@ -0,0 +1,98 @@
from collections.abc import Sequence
from typing import Any, Iterable, List, Optional
class RangeSet(Sequence):
def __init__(self, ranges: Iterable[range] = []):
self.__ranges: List[range] = []
for r in ranges:
assert r.step == 1
self.add(r.start, r.stop)
def add(self, start: int, stop: Optional[int] = None) -> None:
if stop is None:
stop = start + 1
assert stop > start
for i, r in enumerate(self.__ranges):
# the added range is entirely before current item, insert here
if stop < r.start:
self.__ranges.insert(i, range(start, stop))
return
# the added range is entirely after current item, keep looking
if start > r.stop:
continue
# the added range touches the current item, merge it
start = min(start, r.start)
stop = max(stop, r.stop)
while i < len(self.__ranges) - 1 and self.__ranges[i + 1].start <= stop:
stop = max(self.__ranges[i + 1].stop, stop)
self.__ranges.pop(i + 1)
self.__ranges[i] = range(start, stop)
return
# the added range is entirely after all existing items, append it
self.__ranges.append(range(start, stop))
def bounds(self) -> range:
return range(self.__ranges[0].start, self.__ranges[-1].stop)
def shift(self) -> range:
return self.__ranges.pop(0)
def subtract(self, start: int, stop: int) -> None:
assert stop > start
i = 0
while i < len(self.__ranges):
r = self.__ranges[i]
# the removed range is entirely before current item, stop here
if stop <= r.start:
return
# the removed range is entirely after current item, keep looking
if start >= r.stop:
i += 1
continue
# the removed range completely covers the current item, remove it
if start <= r.start and stop >= r.stop:
self.__ranges.pop(i)
continue
# the removed range touches the current item
if start > r.start:
self.__ranges[i] = range(r.start, start)
if stop < r.stop:
self.__ranges.insert(i + 1, range(stop, r.stop))
else:
self.__ranges[i] = range(stop, r.stop)
i += 1
def __bool__(self) -> bool:
raise NotImplementedError
def __contains__(self, val: Any) -> bool:
for r in self.__ranges:
if val in r:
return True
return False
def __eq__(self, other: object) -> bool:
if not isinstance(other, RangeSet):
return NotImplemented
return self.__ranges == other.__ranges
def __getitem__(self, key: Any) -> range:
return self.__ranges[key]
def __len__(self) -> int:
return len(self.__ranges)
def __repr__(self) -> str:
return "RangeSet({})".format(repr(self.__ranges))


@ -0,0 +1,497 @@
import math
from typing import Callable, Dict, Iterable, List, Optional
from .logger import QuicLoggerTrace
from .packet_builder import QuicDeliveryState, QuicSentPacket
from .rangeset import RangeSet
# loss detection
K_PACKET_THRESHOLD = 3
K_INITIAL_RTT = 0.5 # seconds
K_GRANULARITY = 0.001 # seconds
K_TIME_THRESHOLD = 9 / 8
K_MICRO_SECOND = 0.000001
K_SECOND = 1.0
# congestion control
K_MAX_DATAGRAM_SIZE = 1280
K_INITIAL_WINDOW = 10 * K_MAX_DATAGRAM_SIZE
K_MINIMUM_WINDOW = 2 * K_MAX_DATAGRAM_SIZE
K_LOSS_REDUCTION_FACTOR = 0.5
class QuicPacketSpace:
def __init__(self) -> None:
self.ack_at: Optional[float] = None
self.ack_queue = RangeSet()
self.discarded = False
self.expected_packet_number = 0
self.largest_received_packet = -1
self.largest_received_time: Optional[float] = None
# sent packets and loss
self.ack_eliciting_in_flight = 0
self.largest_acked_packet = 0
self.loss_time: Optional[float] = None
self.sent_packets: Dict[int, QuicSentPacket] = {}
class QuicPacketPacer:
def __init__(self) -> None:
self.bucket_max: float = 0.0
self.bucket_time: float = 0.0
self.evaluation_time: float = 0.0
self.packet_time: Optional[float] = None
def next_send_time(self, now: float) -> Optional[float]:
if self.packet_time is not None:
self.update_bucket(now=now)
if self.bucket_time <= 0:
return now + self.packet_time
return None
def update_after_send(self, now: float) -> None:
if self.packet_time is not None:
self.update_bucket(now=now)
if self.bucket_time < self.packet_time:
self.bucket_time = 0.0
else:
self.bucket_time -= self.packet_time
def update_bucket(self, now: float) -> None:
if now > self.evaluation_time:
self.bucket_time = min(
self.bucket_time + (now - self.evaluation_time), self.bucket_max
)
self.evaluation_time = now
def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None:
pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND)
self.packet_time = max(
K_MICRO_SECOND, min(K_MAX_DATAGRAM_SIZE / pacing_rate, K_SECOND)
)
self.bucket_max = (
max(
2 * K_MAX_DATAGRAM_SIZE,
min(congestion_window // 4, 16 * K_MAX_DATAGRAM_SIZE),
)
/ pacing_rate
)
if self.bucket_time > self.bucket_max:
self.bucket_time = self.bucket_max
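# Worked example for the pacer above: with a 12-packet congestion window and
# a 100 ms smoothed RTT the pacing rate is 153,600 bytes/s, i.e. one
# full-sized packet roughly every 8.3 ms, with a burst bucket worth about
# 25 ms of data. The numbers are illustrative only.
pacer = QuicPacketPacer()
pacer.update_rate(congestion_window=12 * K_MAX_DATAGRAM_SIZE, smoothed_rtt=0.1)
assert round(pacer.packet_time, 4) == 0.0083
assert round(pacer.bucket_max, 3) == 0.025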
class QuicCongestionControl:
"""
New Reno congestion control.
"""
def __init__(self) -> None:
self.bytes_in_flight = 0
self.congestion_window = K_INITIAL_WINDOW
self._congestion_recovery_start_time = 0.0
self._congestion_stash = 0
self._rtt_monitor = QuicRttMonitor()
self.ssthresh: Optional[int] = None
def on_packet_acked(self, packet: QuicSentPacket) -> None:
self.bytes_in_flight -= packet.sent_bytes
# don't increase window in congestion recovery
if packet.sent_time <= self._congestion_recovery_start_time:
return
if self.ssthresh is None or self.congestion_window < self.ssthresh:
# slow start
self.congestion_window += packet.sent_bytes
else:
# congestion avoidance
self._congestion_stash += packet.sent_bytes
count = self._congestion_stash // self.congestion_window
if count:
self._congestion_stash -= count * self.congestion_window
self.congestion_window += count * K_MAX_DATAGRAM_SIZE
def on_packet_sent(self, packet: QuicSentPacket) -> None:
self.bytes_in_flight += packet.sent_bytes
def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None:
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None:
lost_largest_time = 0.0
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
lost_largest_time = packet.sent_time
# start a new congestion event if packet was sent after the
# start of the previous congestion recovery period.
if lost_largest_time > self._congestion_recovery_start_time:
self._congestion_recovery_start_time = now
self.congestion_window = max(
int(self.congestion_window * K_LOSS_REDUCTION_FACTOR), K_MINIMUM_WINDOW
)
self.ssthresh = self.congestion_window
# TODO : collapse congestion window if persistent congestion
def on_rtt_measurement(self, latest_rtt: float, now: float) -> None:
# check whether we should exit slow start
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
latest_rtt, now
):
self.ssthresh = self.congestion_window
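# Slow-start sketch for the controller above: acknowledging one in-flight
# full-sized packet grows the window by the acked bytes. Epoch is assumed to
# live in aioquic.tls, mirroring the import used by packet_builder.py.
from aioquic.tls import Epoch

cc = QuicCongestionControl()
sample = QuicSentPacket(
    epoch=Epoch.ONE_RTT,
    in_flight=True,
    is_ack_eliciting=True,
    is_crypto_packet=False,
    packet_number=0,
    packet_type=0,
    sent_time=1.0,
    sent_bytes=K_MAX_DATAGRAM_SIZE,
)
cc.on_packet_sent(sample)
window_before = cc.congestion_window
cc.on_packet_acked(sample)  # slow start: window grows by sent_bytes
assert cc.congestion_window == window_before + K_MAX_DATAGRAM_SIZE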
class QuicPacketRecovery:
"""
Packet loss and congestion controller.
"""
def __init__(
self,
is_client_without_1rtt: bool,
send_probe: Callable[[], None],
quic_logger: Optional[QuicLoggerTrace] = None,
) -> None:
self.is_client_without_1rtt = is_client_without_1rtt
self.max_ack_delay = 0.025
self.spaces: List[QuicPacketSpace] = []
# callbacks
self._quic_logger = quic_logger
self._send_probe = send_probe
# loss detection
self._pto_count = 0
self._rtt_initialized = False
self._rtt_latest = 0.0
self._rtt_min = math.inf
self._rtt_smoothed = 0.0
self._rtt_variance = 0.0
self._time_of_last_sent_ack_eliciting_packet = 0.0
# congestion control
self._cc = QuicCongestionControl()
self._pacer = QuicPacketPacer()
@property
def bytes_in_flight(self) -> int:
return self._cc.bytes_in_flight
@property
def congestion_window(self) -> int:
return self._cc.congestion_window
def discard_space(self, space: QuicPacketSpace) -> None:
assert space in self.spaces
self._cc.on_packets_expired(
filter(lambda x: x.in_flight, space.sent_packets.values())
)
space.sent_packets.clear()
space.ack_at = None
space.ack_eliciting_in_flight = 0
space.loss_time = None
if self._quic_logger is not None:
self._log_metrics_updated()
def get_earliest_loss_space(self) -> Optional[QuicPacketSpace]:
loss_space = None
for space in self.spaces:
if space.loss_time is not None and (
loss_space is None or space.loss_time < loss_space.loss_time
):
loss_space = space
return loss_space
def get_loss_detection_time(self) -> Optional[float]:
# loss timer
loss_space = self.get_earliest_loss_space()
if loss_space is not None:
return loss_space.loss_time
# packet timer
if (
self.is_client_without_1rtt
or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
):
if not self._rtt_initialized:
timeout = 2 * K_INITIAL_RTT * (2 ** self._pto_count)
else:
timeout = self.get_probe_timeout() * (2 ** self._pto_count)
return self._time_of_last_sent_ack_eliciting_packet + timeout
return None
def get_probe_timeout(self) -> float:
return (
self._rtt_smoothed
+ max(4 * self._rtt_variance, K_GRANULARITY)
+ self.max_ack_delay
)
def on_ack_received(
self,
space: QuicPacketSpace,
ack_rangeset: RangeSet,
ack_delay: float,
now: float,
) -> None:
"""
Update metrics as the result of an ACK being received.
"""
is_ack_eliciting = False
largest_acked = ack_rangeset.bounds().stop - 1
largest_newly_acked = None
largest_sent_time = None
if largest_acked > space.largest_acked_packet:
space.largest_acked_packet = largest_acked
for packet_number in sorted(space.sent_packets.keys()):
if packet_number > largest_acked:
break
if packet_number in ack_rangeset:
# remove packet and update counters
packet = space.sent_packets.pop(packet_number)
if packet.is_ack_eliciting:
is_ack_eliciting = True
space.ack_eliciting_in_flight -= 1
if packet.in_flight:
self._cc.on_packet_acked(packet)
largest_newly_acked = packet_number
largest_sent_time = packet.sent_time
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.ACKED, *args)
# nothing to do if there are no newly acked packets
if largest_newly_acked is None:
return
if largest_acked == largest_newly_acked and is_ack_eliciting:
latest_rtt = now - largest_sent_time
log_rtt = True
# limit ACK delay to max_ack_delay
ack_delay = min(ack_delay, self.max_ack_delay)
# update RTT estimate, which cannot be < 1 ms
self._rtt_latest = max(latest_rtt, 0.001)
if self._rtt_latest < self._rtt_min:
self._rtt_min = self._rtt_latest
if self._rtt_latest > self._rtt_min + ack_delay:
self._rtt_latest -= ack_delay
if not self._rtt_initialized:
self._rtt_initialized = True
self._rtt_variance = latest_rtt / 2
self._rtt_smoothed = latest_rtt
else:
self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
self._rtt_min - self._rtt_latest
)
self._rtt_smoothed = (
7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
)
# inform congestion controller
self._cc.on_rtt_measurement(latest_rtt, now=now)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
else:
log_rtt = False
self._detect_loss(space, now=now)
if self._quic_logger is not None:
self._log_metrics_updated(log_rtt=log_rtt)
self._pto_count = 0
def on_loss_detection_timeout(self, now: float) -> None:
loss_space = self.get_earliest_loss_space()
if loss_space is not None:
self._detect_loss(loss_space, now=now)
else:
self._pto_count += 1
# reschedule some data
for space in self.spaces:
self._on_packets_lost(
tuple(
filter(
lambda i: i.is_crypto_packet, space.sent_packets.values()
)
),
space=space,
now=now,
)
self._send_probe()
def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
space.sent_packets[packet.packet_number] = packet
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight += 1
if packet.in_flight:
if packet.is_ack_eliciting:
self._time_of_last_sent_ack_eliciting_packet = packet.sent_time
# add packet to bytes in flight
self._cc.on_packet_sent(packet)
if self._quic_logger is not None:
self._log_metrics_updated()
def _detect_loss(self, space: QuicPacketSpace, now: float) -> None:
"""
Check whether any packets should be declared lost.
"""
loss_delay = K_TIME_THRESHOLD * (
max(self._rtt_latest, self._rtt_smoothed)
if self._rtt_initialized
else K_INITIAL_RTT
)
packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
time_threshold = now - loss_delay
lost_packets = []
space.loss_time = None
for packet_number, packet in space.sent_packets.items():
if packet_number > space.largest_acked_packet:
break
if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
lost_packets.append(packet)
else:
packet_loss_time = packet.sent_time + loss_delay
if space.loss_time is None or space.loss_time > packet_loss_time:
space.loss_time = packet_loss_time
self._on_packets_lost(lost_packets, space=space, now=now)
def _log_metrics_updated(self, log_rtt=False) -> None:
data = {
"bytes_in_flight": self._cc.bytes_in_flight,
"cwnd": self._cc.congestion_window,
}
if self._cc.ssthresh is not None:
data["ssthresh"] = self._cc.ssthresh
if log_rtt:
data.update(
{
"latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
"min_rtt": self._quic_logger.encode_time(self._rtt_min),
"smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
"rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
}
)
self._quic_logger.log_event(
category="recovery", event="metrics_updated", data=data
)
def _on_packets_lost(
self, packets: Iterable[QuicSentPacket], space: QuicPacketSpace, now: float
) -> None:
lost_packets_cc = []
for packet in packets:
del space.sent_packets[packet.packet_number]
if packet.in_flight:
lost_packets_cc.append(packet)
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight -= 1
if self._quic_logger is not None:
self._quic_logger.log_event(
category="recovery",
event="packet_lost",
data={
"type": self._quic_logger.packet_type(packet.packet_type),
"packet_number": str(packet.packet_number),
},
)
self._log_metrics_updated()
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.LOST, *args)
# inform congestion controller
if lost_packets_cc:
self._cc.on_packets_lost(lost_packets_cc, now=now)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
if self._quic_logger is not None:
self._log_metrics_updated()
class QuicRttMonitor:
"""
Roundtrip time monitor for HyStart.
"""
def __init__(self) -> None:
self._increases = 0
self._last_time = None
self._ready = False
self._size = 5
self._filtered_min: Optional[float] = None
self._sample_idx = 0
self._sample_max: Optional[float] = None
self._sample_min: Optional[float] = None
self._sample_time = 0.0
self._samples = [0.0 for i in range(self._size)]
def add_rtt(self, rtt: float) -> None:
self._samples[self._sample_idx] = rtt
self._sample_idx += 1
if self._sample_idx >= self._size:
self._sample_idx = 0
self._ready = True
if self._ready:
self._sample_max = self._samples[0]
self._sample_min = self._samples[0]
for sample in self._samples[1:]:
if sample < self._sample_min:
self._sample_min = sample
elif sample > self._sample_max:
self._sample_max = sample
def is_rtt_increasing(self, rtt: float, now: float) -> bool:
if now > self._sample_time + K_GRANULARITY:
self.add_rtt(rtt)
self._sample_time = now
if self._ready:
if self._filtered_min is None or self._filtered_min > self._sample_max:
self._filtered_min = self._sample_max
delta = self._sample_min - self._filtered_min
if delta * 4 >= self._filtered_min:
self._increases += 1
if self._increases >= self._size:
return True
elif delta > 0:
self._increases = 0
return False
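# Illustration of the HyStart check above: feeding a steadily increasing RTT,
# one sample every 2 ms (above K_GRANULARITY), is eventually reported as an
# increasing trend once the windowed minimum exceeds the filtered minimum by
# 25% or more for several consecutive samples.
monitor = QuicRttMonitor()
now = 0.0
increasing = False
for i in range(20):
    now += 0.002
    increasing = monitor.is_rtt_increasing(0.050 + i * 0.005, now)
assert increasing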


@ -0,0 +1,39 @@
import ipaddress
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from .connection import NetworkAddress
def encode_address(addr: NetworkAddress) -> bytes:
return ipaddress.ip_address(addr[0]).packed + bytes([addr[1] >> 8, addr[1] & 0xFF])
class QuicRetryTokenHandler:
def __init__(self) -> None:
self._key = rsa.generate_private_key(
public_exponent=65537, key_size=1024, backend=default_backend()
)
def create_token(self, addr: NetworkAddress, destination_cid: bytes) -> bytes:
retry_message = encode_address(addr) + b"|" + destination_cid
return self._key.public_key().encrypt(
retry_message,
padding.OAEP(
mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None
),
)
def validate_token(self, addr: NetworkAddress, token: bytes) -> bytes:
retry_message = self._key.decrypt(
token,
padding.OAEP(
mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None
),
)
encoded_addr, original_connection_id = retry_message.split(b"|", maxsplit=1)
if encoded_addr != encode_address(addr):
raise ValueError("Remote address does not match.")
return original_connection_id
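# Usage sketch for QuicRetryTokenHandler: a token minted for an address
# decrypts back to the original destination connection ID, and validation is
# bound to that address (a different address raises ValueError).
handler = QuicRetryTokenHandler()
addr = ("1.2.3.4", 1234)
token = handler.create_token(addr, destination_cid=b"\x01" * 8)
assert handler.validate_token(addr, token) == b"\x01" * 8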


@ -0,0 +1,214 @@
from typing import Optional
from . import events
from .packet import QuicStreamFrame
from .packet_builder import QuicDeliveryState
from .rangeset import RangeSet
class QuicStream:
def __init__(
self,
stream_id: Optional[int] = None,
max_stream_data_local: int = 0,
max_stream_data_remote: int = 0,
) -> None:
self.is_blocked = False
self.max_stream_data_local = max_stream_data_local
self.max_stream_data_local_sent = max_stream_data_local
self.max_stream_data_remote = max_stream_data_remote
self.send_buffer_is_empty = True
self._recv_buffer = bytearray()
self._recv_buffer_fin: Optional[int] = None
self._recv_buffer_start = 0 # the offset for the start of the buffer
self._recv_highest = 0 # the highest offset ever seen
self._recv_ranges = RangeSet()
self._send_acked = RangeSet()
self._send_buffer = bytearray()
self._send_buffer_fin: Optional[int] = None
self._send_buffer_start = 0 # the offset for the start of the buffer
self._send_buffer_stop = 0 # the offset for the stop of the buffer
self._send_highest = 0
self._send_pending = RangeSet()
self._send_pending_eof = False
self.__stream_id = stream_id
@property
def stream_id(self) -> Optional[int]:
return self.__stream_id
# reader
def add_frame(self, frame: QuicStreamFrame) -> Optional[events.StreamDataReceived]:
"""
Add a frame of received data.
"""
pos = frame.offset - self._recv_buffer_start
count = len(frame.data)
frame_end = frame.offset + count
# we should receive no more data beyond FIN!
if self._recv_buffer_fin is not None and frame_end > self._recv_buffer_fin:
raise Exception("Data received beyond FIN")
if frame.fin:
self._recv_buffer_fin = frame_end
if frame_end > self._recv_highest:
self._recv_highest = frame_end
# fast path: new in-order chunk
if pos == 0 and count and not self._recv_buffer:
self._recv_buffer_start += count
return events.StreamDataReceived(
data=frame.data, end_stream=frame.fin, stream_id=self.__stream_id
)
# discard duplicate data
if pos < 0:
frame.data = frame.data[-pos:]
frame.offset -= pos
pos = 0
# mark the received range
if frame_end > frame.offset:
self._recv_ranges.add(frame.offset, frame_end)
# add new data
gap = pos - len(self._recv_buffer)
if gap > 0:
self._recv_buffer += bytearray(gap)
self._recv_buffer[pos : pos + count] = frame.data
# return data from the front of the buffer
data = self._pull_data()
end_stream = self._recv_buffer_start == self._recv_buffer_fin
if data or end_stream:
return events.StreamDataReceived(
data=data, end_stream=end_stream, stream_id=self.__stream_id
)
else:
return None
def _pull_data(self) -> bytes:
"""
Remove data from the front of the buffer.
"""
try:
has_data_to_read = self._recv_ranges[0].start == self._recv_buffer_start
except IndexError:
has_data_to_read = False
if not has_data_to_read:
return b""
r = self._recv_ranges.shift()
pos = r.stop - r.start
data = bytes(self._recv_buffer[:pos])
del self._recv_buffer[:pos]
self._recv_buffer_start = r.stop
return data
# writer
@property
def next_send_offset(self) -> int:
"""
The offset for the next frame to send.
This is used to determine the space needed for the frame's `offset` field.
"""
try:
return self._send_pending[0].start
except IndexError:
return self._send_buffer_stop
def get_frame(
self, max_size: int, max_offset: Optional[int] = None
) -> Optional[QuicStreamFrame]:
"""
Get a frame of data to send.
"""
# get the first pending data range
try:
r = self._send_pending[0]
except IndexError:
if self._send_pending_eof:
# FIN only
self._send_pending_eof = False
return QuicStreamFrame(fin=True, offset=self._send_buffer_fin)
self.send_buffer_is_empty = True
return None
# apply flow control
start = r.start
stop = min(r.stop, start + max_size)
if max_offset is not None and stop > max_offset:
stop = max_offset
if stop <= start:
return None
# create frame
frame = QuicStreamFrame(
data=bytes(
self._send_buffer[
start - self._send_buffer_start : stop - self._send_buffer_start
]
),
offset=start,
)
self._send_pending.subtract(start, stop)
# track the highest offset ever sent
if stop > self._send_highest:
self._send_highest = stop
# if the buffer is empty and EOF was written, set the FIN bit
if self._send_buffer_fin == stop:
frame.fin = True
self._send_pending_eof = False
return frame
def on_data_delivery(
self, delivery: QuicDeliveryState, start: int, stop: int
) -> None:
"""
Callback when sent data is acknowledged or lost.
"""
self.send_buffer_is_empty = False
if delivery == QuicDeliveryState.ACKED:
if stop > start:
self._send_acked.add(start, stop)
first_range = self._send_acked[0]
if first_range.start == self._send_buffer_start:
size = first_range.stop - first_range.start
self._send_acked.shift()
self._send_buffer_start += size
del self._send_buffer[:size]
else:
if stop > start:
self._send_pending.add(start, stop)
if stop == self._send_buffer_fin:
self.send_buffer_is_empty = False
self._send_pending_eof = True
def write(self, data: bytes, end_stream: bool = False) -> None:
"""
Write some data bytes to the QUIC stream.
"""
assert self._send_buffer_fin is None, "cannot call write() after FIN"
size = len(data)
if size:
self.send_buffer_is_empty = False
self._send_pending.add(
self._send_buffer_stop, self._send_buffer_stop + size
)
self._send_buffer += data
self._send_buffer_stop += size
if end_stream:
self.send_buffer_is_empty = False
self._send_buffer_fin = self._send_buffer_stop
self._send_pending_eof = True
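# End-to-end sketch for QuicStream: bytes written on a sending stream are
# turned into a QuicStreamFrame, and feeding that frame to a receiving stream
# yields a StreamDataReceived event carrying the same data.
sender = QuicStream(stream_id=0)
sender.write(b"hello", end_stream=True)
frame = sender.get_frame(max_size=1024)  # data=b"hello", offset=0, fin=True

receiver = QuicStream(stream_id=0)
event = receiver.add_frame(frame)
assert event.data == b"hello" and event.end_stream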

Diff not shown because of its large size.


Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/initial_client.bin vendored Normal file

Binary file not shown.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/initial_server.bin vendored Normal file

Binary file not shown.


@ -0,0 +1,99 @@
Certificate:
Data:
Version: 3 (0x2)
Serial Number:
cb:2d:80:99:5a:69:52:5b
Signature Algorithm: sha256WithRSAEncryption
Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server
Validity
Not Before: Aug 29 14:23:16 2018 GMT
Not After : Aug 26 14:23:16 2028 GMT
Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server
Subject Public Key Info:
Public Key Algorithm: rsaEncryption
Public-Key: (3072 bit)
Modulus:
00:97:ed:55:41:ba:36:17:95:db:71:1c:d3:e1:61:
ac:58:73:e3:c6:96:cf:2b:1f:b8:08:f5:9d:4b:4b:
c7:30:f6:b8:0b:b3:52:72:a0:bb:c9:4d:3b:8e:df:
22:8e:01:57:81:c9:92:73:cc:00:c6:ec:70:b0:3a:
17:40:c1:df:f2:8c:36:4c:c4:a7:81:e7:b6:24:68:
e2:a0:7e:35:07:2f:a0:5b:f9:45:46:f7:1e:f0:46:
11:fe:ca:1a:3c:50:f1:26:a9:5f:9c:22:9c:f8:41:
e1:df:4f:12:95:19:2f:5c:90:01:17:6e:7e:3e:7d:
cf:e9:09:af:25:f8:f8:42:77:2d:6d:5f:36:f2:78:
1e:7d:4a:87:68:63:6c:06:71:1b:8d:fa:25:fe:d4:
d3:f5:a5:17:b1:ef:ea:17:cb:54:c8:27:99:80:cb:
3c:45:f1:2c:52:1c:dd:1f:51:45:20:50:1e:5e:ab:
57:73:1b:41:78:96:de:84:a4:7a:dd:8f:30:85:36:
58:79:76:a0:d2:61:c8:1b:a9:94:99:63:c6:ee:f8:
14:bf:b4:52:56:31:97:fa:eb:ac:53:9e:95:ce:4c:
c4:5a:4a:b7:ca:03:27:5b:35:57:ce:02:dc:ec:ca:
69:f8:8a:5a:39:cb:16:20:15:03:24:61:6c:f4:7a:
fc:b6:48:e5:59:10:5c:49:d0:23:9f:fb:71:5e:3a:
e9:68:9f:34:72:80:27:b6:3f:4c:b1:d9:db:63:7f:
67:68:4a:6e:11:f8:e8:c0:f4:5a:16:39:53:0b:68:
de:77:fa:45:e7:f8:91:cd:78:cd:28:94:97:71:54:
fb:cf:f0:37:de:c9:26:c5:dc:1b:9e:89:6d:09:ac:
c8:44:71:cb:6d:f1:97:31:d5:4c:20:33:bf:75:4a:
a0:e0:dc:69:11:ed:2a:b4:64:10:11:30:8b:0e:b0:
a7:10:d8:8a:c5:aa:1b:c8:26:8a:25:e7:66:9f:a5:
6a:1a:2f:7c:5f:83:c6:78:4f:1f
Exponent: 65537 (0x10001)
X509v3 extensions:
X509v3 Subject Key Identifier:
DD:BF:CA:DA:E6:D1:34:BA:37:75:21:CA:6F:9A:08:28:F2:35:B6:48
X509v3 Authority Key Identifier:
keyid:DD:BF:CA:DA:E6:D1:34:BA:37:75:21:CA:6F:9A:08:28:F2:35:B6:48
X509v3 Basic Constraints:
CA:TRUE
Signature Algorithm: sha256WithRSAEncryption
33:6a:54:d3:6b:c0:d7:01:5f:9d:f4:05:c1:93:66:90:50:d0:
b7:18:e9:b0:1e:4a:a0:b6:da:76:93:af:84:db:ad:15:54:31:
15:13:e4:de:7e:4e:0c:d5:09:1c:34:35:b6:e5:4c:d6:6f:65:
7d:32:5f:eb:fc:a9:6b:07:f7:49:82:e5:81:7e:07:80:9a:63:
f8:2c:c3:40:bc:8f:d4:2a:da:3e:d1:ee:08:b7:4d:a7:84:ca:
f4:3f:a1:98:45:be:b1:05:69:e7:df:d7:99:ab:1b:ee:8b:30:
cc:f7:fc:e7:d4:0b:17:ae:97:bf:e4:7b:fd:0f:a7:b4:85:79:
e3:59:e2:16:87:bf:1f:29:45:2c:23:93:76:be:c0:87:1d:de:
ec:2b:42:6a:e5:bb:c8:f4:0a:4a:08:0a:8c:5c:d8:7d:4d:d1:
b8:bf:d5:f7:29:ed:92:d1:94:04:e8:35:06:57:7f:2c:23:97:
87:a5:35:8d:26:d3:1a:47:f2:16:d7:d9:c6:d4:1f:23:43:d3:
26:99:39:ca:20:f4:71:23:6f:0c:4a:76:76:f7:76:1f:b3:fe:
bf:47:b0:fc:2a:56:81:e1:d2:dd:ee:08:d8:f4:ff:5a:dc:25:
61:8a:91:02:b9:86:1c:f2:50:73:76:25:35:fc:b6:25:26:15:
cb:eb:c4:2b:61:0c:1c:e7:ee:2f:17:9b:ec:f0:d4:a1:84:e7:
d2:af:de:e4:1b:24:14:a7:01:87:e3:ab:29:58:46:a0:d9:c0:
0a:e0:8d:d7:59:d3:1b:f8:54:20:3e:78:a5:a5:c8:4f:8b:03:
c4:96:9f:ec:fb:47:cf:76:2d:8d:65:34:27:bf:fa:ae:01:05:
8a:f3:92:0a:dd:89:6c:97:a1:c7:e7:60:51:e7:ac:eb:4b:7d:
2c:b8:65:c9:fe:5d:6a:48:55:8e:e4:c7:f9:6a:40:e1:b8:64:
45:e9:b5:59:29:a5:5f:cf:7d:58:7d:64:79:e5:a4:09:ac:1e:
76:65:3d:94:c4:68
-----BEGIN CERTIFICATE-----
MIIEbTCCAtWgAwIBAgIJAMstgJlaaVJbMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV
BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW
MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA4MjYx
NDIzMTZaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg
Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCAaIwDQYJKoZI
hvcNAQEBBQADggGPADCCAYoCggGBAJftVUG6NheV23Ec0+FhrFhz48aWzysfuAj1
nUtLxzD2uAuzUnKgu8lNO47fIo4BV4HJknPMAMbscLA6F0DB3/KMNkzEp4HntiRo
4qB+NQcvoFv5RUb3HvBGEf7KGjxQ8SapX5winPhB4d9PEpUZL1yQARdufj59z+kJ
ryX4+EJ3LW1fNvJ4Hn1Kh2hjbAZxG436Jf7U0/WlF7Hv6hfLVMgnmYDLPEXxLFIc
3R9RRSBQHl6rV3MbQXiW3oSket2PMIU2WHl2oNJhyBuplJljxu74FL+0UlYxl/rr
rFOelc5MxFpKt8oDJ1s1V84C3OzKafiKWjnLFiAVAyRhbPR6/LZI5VkQXEnQI5/7
cV466WifNHKAJ7Y/TLHZ22N/Z2hKbhH46MD0WhY5Uwto3nf6Ref4kc14zSiUl3FU
+8/wN97JJsXcG56JbQmsyERxy23xlzHVTCAzv3VKoODcaRHtKrRkEBEwiw6wpxDY
isWqG8gmiiXnZp+lahovfF+DxnhPHwIDAQABo1AwTjAdBgNVHQ4EFgQU3b/K2ubR
NLo3dSHKb5oIKPI1tkgwHwYDVR0jBBgwFoAU3b/K2ubRNLo3dSHKb5oIKPI1tkgw
DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAYEAM2pU02vA1wFfnfQFwZNm
kFDQtxjpsB5KoLbadpOvhNutFVQxFRPk3n5ODNUJHDQ1tuVM1m9lfTJf6/ypawf3
SYLlgX4HgJpj+CzDQLyP1CraPtHuCLdNp4TK9D+hmEW+sQVp59/Xmasb7oswzPf8
59QLF66Xv+R7/Q+ntIV541niFoe/HylFLCOTdr7Ahx3e7CtCauW7yPQKSggKjFzY
fU3RuL/V9yntktGUBOg1Bld/LCOXh6U1jSbTGkfyFtfZxtQfI0PTJpk5yiD0cSNv
DEp2dvd2H7P+v0ew/CpWgeHS3e4I2PT/WtwlYYqRArmGHPJQc3YlNfy2JSYVy+vE
K2EMHOfuLxeb7PDUoYTn0q/e5BskFKcBh+OrKVhGoNnACuCN11nTG/hUID54paXI
T4sDxJaf7PtHz3YtjWU0J7/6rgEFivOSCt2JbJehx+dgUees60t9LLhlyf5dakhV
juTH+WpA4bhkRem1WSmlX899WH1keeWkCawedmU9lMRo
-----END CERTIFICATE-----

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/retry.bin vendored Normal file

Binary file not shown.


@ -0,0 +1 @@
]ôZ§µœÖæhõ0LÔý³y“'


@ -0,0 +1,34 @@
-----BEGIN CERTIFICATE-----
MIIF8TCCBFmgAwIBAgIJAMstgJlaaVJcMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV
BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW
MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA3MDcx
NDIzMTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQHDA5DYXN0bGUgQW50aHJheDEj
MCEGA1UECgwaUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMMCWxv
Y2FsaG9zdDCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoCggGBAJ8oLzdB739k
YxZiFukBFGIpyjqYkj0I015p/sDz1MT7DljcZLBLy7OqnkLpB5tnM8256DwdihPA
3zlnfEzTfr9DD0qFBW2H5cMCoz7X17koeRhzGDd3dkjUeBjXvR5qRosG8wM3lQug
U7AizY+3Azaj1yN3mZ9K5a20jr58Kqinz+Xxx6sb2JfYYff2neJbBahNm5id0AD2
pi/TthZqO5DURJYo+MdgZOcy+7jEjOJsLWZd3Yzq78iM07qDjbpIoVpENZCTHTWA
hX8LIqz0OBmh4weQpm4+plU7E4r4D82uauocWw8iyuznCTtABWO7n9fWySmf9QZC
WYxHAFpBQs6zUVqAD7nhFdTqpQ9bRiaEnjE4HiAccPW+MAoSxFnv/rNzEzI6b4zU
NspFMfg1aNVamdjxdpUZ1GG1Okf0yPJykqEX4PZl3La1Be2q7YZ1wydR523Xd+f3
EO4/g+imETSKn8gyCf6Rvib175L4r2WV1CXQH7gFwZYCod6WHYq5TQIDAQABo4IB
wDCCAbwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA4GA1UdDwEB/wQEAwIFoDAdBgNV
HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E
FgQUj+od4zNcABazi29rb9NMy7XLfFUwfQYDVR0jBHYwdIAU3b/K2ubRNLo3dSHK
b5oIKPI1tkihUaRPME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29m
dHdhcmUgRm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcoIJAMst
gJlaaVJbMIGDBggrBgEFBQcBAQR3MHUwPAYIKwYBBQUHMAKGMGh0dHA6Ly90ZXN0
Y2EucHl0aG9udGVzdC5uZXQvdGVzdGNhL3B5Y2FjZXJ0LmNlcjA1BggrBgEFBQcw
AYYpaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0Y2Evb2NzcC8wQwYD
VR0fBDwwOjA4oDagNIYyaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0
Y2EvcmV2b2NhdGlvbi5jcmwwDQYJKoZIhvcNAQELBQADggGBACf1jFkQ9MbnKAC/
uo17EwPxHKZfswZVpCK527LVRr33DN1DbrR5ZWchDCpV7kCOhZ+fR7sKKk22ZHSY
oH+u3PEu20J3GOB1iyY1aMNB7WvId3JvappdVWkC/VpUyFfLsGUDFuIPADmZZqCb
iJMX4loteTVfl1d4xK/1mV6Gq9MRrRqiDfpSELn+v53OM9mGspwW+NZ1CIrbCuW0
KxZ/tPkqn8PSd9fNZR70bB7rWbnwrl+kH8xKxLl6qdlrMmg74WWwhLeQxK7+9DdP
IaDenzqx5cwWBGY/C0HcQj0gPuy3lSs1V/q+f7Y6uspPWP51PgiJLIywXS75iRAr
+UFGTzwAtyfTZSQoFyMmMULqfk6T5HtoVMqfRvPvK+mFDLWEstU1NIB1K/CRI7gI
AY65ClTU+zRS/tlF8IA7tsFvgtEf8jsI9kamlidhS1gyeg4dWcVErV4aeTPB1AUv
StPYQkKNM+NjytWHl5tNuBoDNLsc0gI/WSPiI4CIY8LwomOoiw==
-----END CERTIFICATE-----


@ -0,0 +1,60 @@
-----BEGIN CERTIFICATE-----
MIIF8TCCBFmgAwIBAgIJAMstgJlaaVJcMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV
BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW
MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA3MDcx
NDIzMTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQHDA5DYXN0bGUgQW50aHJheDEj
MCEGA1UECgwaUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMMCWxv
Y2FsaG9zdDCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoCggGBAJ8oLzdB739k
YxZiFukBFGIpyjqYkj0I015p/sDz1MT7DljcZLBLy7OqnkLpB5tnM8256DwdihPA
3zlnfEzTfr9DD0qFBW2H5cMCoz7X17koeRhzGDd3dkjUeBjXvR5qRosG8wM3lQug
U7AizY+3Azaj1yN3mZ9K5a20jr58Kqinz+Xxx6sb2JfYYff2neJbBahNm5id0AD2
pi/TthZqO5DURJYo+MdgZOcy+7jEjOJsLWZd3Yzq78iM07qDjbpIoVpENZCTHTWA
hX8LIqz0OBmh4weQpm4+plU7E4r4D82uauocWw8iyuznCTtABWO7n9fWySmf9QZC
WYxHAFpBQs6zUVqAD7nhFdTqpQ9bRiaEnjE4HiAccPW+MAoSxFnv/rNzEzI6b4zU
NspFMfg1aNVamdjxdpUZ1GG1Okf0yPJykqEX4PZl3La1Be2q7YZ1wydR523Xd+f3
EO4/g+imETSKn8gyCf6Rvib175L4r2WV1CXQH7gFwZYCod6WHYq5TQIDAQABo4IB
wDCCAbwwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA4GA1UdDwEB/wQEAwIFoDAdBgNV
HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E
FgQUj+od4zNcABazi29rb9NMy7XLfFUwfQYDVR0jBHYwdIAU3b/K2ubRNLo3dSHK
b5oIKPI1tkihUaRPME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29m
dHdhcmUgRm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcoIJAMst
gJlaaVJbMIGDBggrBgEFBQcBAQR3MHUwPAYIKwYBBQUHMAKGMGh0dHA6Ly90ZXN0
Y2EucHl0aG9udGVzdC5uZXQvdGVzdGNhL3B5Y2FjZXJ0LmNlcjA1BggrBgEFBQcw
AYYpaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0Y2Evb2NzcC8wQwYD
VR0fBDwwOjA4oDagNIYyaHR0cDovL3Rlc3RjYS5weXRob250ZXN0Lm5ldC90ZXN0
Y2EvcmV2b2NhdGlvbi5jcmwwDQYJKoZIhvcNAQELBQADggGBACf1jFkQ9MbnKAC/
uo17EwPxHKZfswZVpCK527LVRr33DN1DbrR5ZWchDCpV7kCOhZ+fR7sKKk22ZHSY
oH+u3PEu20J3GOB1iyY1aMNB7WvId3JvappdVWkC/VpUyFfLsGUDFuIPADmZZqCb
iJMX4loteTVfl1d4xK/1mV6Gq9MRrRqiDfpSELn+v53OM9mGspwW+NZ1CIrbCuW0
KxZ/tPkqn8PSd9fNZR70bB7rWbnwrl+kH8xKxLl6qdlrMmg74WWwhLeQxK7+9DdP
IaDenzqx5cwWBGY/C0HcQj0gPuy3lSs1V/q+f7Y6uspPWP51PgiJLIywXS75iRAr
+UFGTzwAtyfTZSQoFyMmMULqfk6T5HtoVMqfRvPvK+mFDLWEstU1NIB1K/CRI7gI
AY65ClTU+zRS/tlF8IA7tsFvgtEf8jsI9kamlidhS1gyeg4dWcVErV4aeTPB1AUv
StPYQkKNM+NjytWHl5tNuBoDNLsc0gI/WSPiI4CIY8LwomOoiw==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIEbTCCAtWgAwIBAgIJAMstgJlaaVJbMA0GCSqGSIb3DQEBCwUAME0xCzAJBgNV
BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW
MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xODA4MjkxNDIzMTZaFw0yODA4MjYx
NDIzMTZaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg
Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCAaIwDQYJKoZI
hvcNAQEBBQADggGPADCCAYoCggGBAJftVUG6NheV23Ec0+FhrFhz48aWzysfuAj1
nUtLxzD2uAuzUnKgu8lNO47fIo4BV4HJknPMAMbscLA6F0DB3/KMNkzEp4HntiRo
4qB+NQcvoFv5RUb3HvBGEf7KGjxQ8SapX5winPhB4d9PEpUZL1yQARdufj59z+kJ
ryX4+EJ3LW1fNvJ4Hn1Kh2hjbAZxG436Jf7U0/WlF7Hv6hfLVMgnmYDLPEXxLFIc
3R9RRSBQHl6rV3MbQXiW3oSket2PMIU2WHl2oNJhyBuplJljxu74FL+0UlYxl/rr
rFOelc5MxFpKt8oDJ1s1V84C3OzKafiKWjnLFiAVAyRhbPR6/LZI5VkQXEnQI5/7
cV466WifNHKAJ7Y/TLHZ22N/Z2hKbhH46MD0WhY5Uwto3nf6Ref4kc14zSiUl3FU
+8/wN97JJsXcG56JbQmsyERxy23xlzHVTCAzv3VKoODcaRHtKrRkEBEwiw6wpxDY
isWqG8gmiiXnZp+lahovfF+DxnhPHwIDAQABo1AwTjAdBgNVHQ4EFgQU3b/K2ubR
NLo3dSHKb5oIKPI1tkgwHwYDVR0jBBgwFoAU3b/K2ubRNLo3dSHKb5oIKPI1tkgw
DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAYEAM2pU02vA1wFfnfQFwZNm
kFDQtxjpsB5KoLbadpOvhNutFVQxFRPk3n5ODNUJHDQ1tuVM1m9lfTJf6/ypawf3
SYLlgX4HgJpj+CzDQLyP1CraPtHuCLdNp4TK9D+hmEW+sQVp59/Xmasb7oswzPf8
59QLF66Xv+R7/Q+ntIV541niFoe/HylFLCOTdr7Ahx3e7CtCauW7yPQKSggKjFzY
fU3RuL/V9yntktGUBOg1Bld/LCOXh6U1jSbTGkfyFtfZxtQfI0PTJpk5yiD0cSNv
DEp2dvd2H7P+v0ew/CpWgeHS3e4I2PT/WtwlYYqRArmGHPJQc3YlNfy2JSYVy+vE
K2EMHOfuLxeb7PDUoYTn0q/e5BskFKcBh+OrKVhGoNnACuCN11nTG/hUID54paXI
T4sDxJaf7PtHz3YtjWU0J7/6rgEFivOSCt2JbJehx+dgUees60t9LLhlyf5dakhV
juTH+WpA4bhkRem1WSmlX899WH1keeWkCawedmU9lMRo
-----END CERTIFICATE-----


@ -0,0 +1,40 @@
-----BEGIN PRIVATE KEY-----
MIIG/QIBADANBgkqhkiG9w0BAQEFAASCBucwggbjAgEAAoIBgQCfKC83Qe9/ZGMW
YhbpARRiKco6mJI9CNNeaf7A89TE+w5Y3GSwS8uzqp5C6QebZzPNueg8HYoTwN85
Z3xM036/Qw9KhQVth+XDAqM+19e5KHkYcxg3d3ZI1HgY170eakaLBvMDN5ULoFOw
Is2PtwM2o9cjd5mfSuWttI6+fCqop8/l8cerG9iX2GH39p3iWwWoTZuYndAA9qYv
07YWajuQ1ESWKPjHYGTnMvu4xIzibC1mXd2M6u/IjNO6g426SKFaRDWQkx01gIV/
CyKs9DgZoeMHkKZuPqZVOxOK+A/NrmrqHFsPIsrs5wk7QAVju5/X1skpn/UGQlmM
RwBaQULOs1FagA+54RXU6qUPW0YmhJ4xOB4gHHD1vjAKEsRZ7/6zcxMyOm+M1DbK
RTH4NWjVWpnY8XaVGdRhtTpH9MjycpKhF+D2Zdy2tQXtqu2GdcMnUedt13fn9xDu
P4PophE0ip/IMgn+kb4m9e+S+K9lldQl0B+4BcGWAqHelh2KuU0CAwEAAQKCAYEA
lKiWIYjmyRjdLKUGPTES9vWNvNmRjozV0RQ0LcoSbMMLDZkeO0UwyWqOVHUQ8+ib
jIcfEjeNJxI57oZopeHOO5vJhpNlFH+g7ltiW2qERqA1K88lSXm99Bzw6FNqhCRE
K8ub5N9fyfJA+P4o/xm0WK8EXk5yIUV17p/9zJJxzgKgv2jsVTi3QG2OZGvn4Oug
ByomMZEGHkBDzdxz8c/cP1Tlk1RFuwSgews178k2xq7AYSM/s0YmHi7b/RSvptX6
1v8P8kXNUe4AwTaNyrlvF2lwIadZ8h1hA7tCE2n44b7a7KfhAkwcbr1T59ioYh6P
zxsyPT678uD51dbtD/DXJCcoeeFOb8uzkR2KNcrnQzZpCJnRq4Gp5ybxwsxxuzpr
gz0gbNlhuWtE7EoSzmIK9t+WTS7IM2CvZymd6/OAh1Fuw6AQhSp64XRp3OfMMAAC
Ie2EPtKj4islWGT8VoUjuRYGmdRh4duAH1dkiAXOWA3R7y5a1/y/iE8KE8BtxocB
AoHBAM8aiURgpu1Fs0Oqz6izec7KSLL3l8hmW+MKUOfk/Ybng6FrTFsL5YtzR+Ap
wW4wwWnnIKEc1JLiZ7g8agRETK8hr5PwFXUn/GSWC0SMsazLJToySQS5LOV0tLzK
kJ3jtNU7tnlDGNkCHTHSoVL2T/8t+IkZI/h5Z6wjlYPvU2Iu0nVIXtiG+alv4A6M
Hrh9l5or4mjB6rGnVXeYohLkCm6s/W97ahVxLMcEdbsBo1prm2JqGnSoiR/tEFC/
QHQnbQKBwQDEu7kW0Yg9sZ89QtYtVQ1YpixFZORaUeRIRLnpEs1w7L1mCbOZ2Lj9
JHxsH05cYAc7HJfPwwxv3+3aGAIC/dfu4VSwEFtatAzUpzlhzKS5+HQCWB4JUNNU
MQ3+FwK2xQX4Ph8t+OzrFiYcK2g0An5UxWMa2HWIAWUOhnTOydAVsoH6yP31cVm4
0hxoABCwflaNLNGjRUyfBpLTAcNu/YtcE+KREy7YAAgXXrhRSO4XpLsSXwLnLT7/
YOkoBWDcTWECgcBPWnSUDZCIQ3efithMZJBciqd2Y2X19Dpq8O31HImD4jtOY0V7
cUB/wSkeHAGwjd/eCyA2e0x8B2IEdqmMfvr+86JJxekC3dJYXCFvH5WIhsH53YCa
3bT1KlWCLP9ib/g+58VQC0R/Cc9T4sfLePNH7D5ZkZd1wlbV30CPr+i8KwKay6MD
xhvtLx+jk07GE+E9wmjbCMo7TclyrLoVEOlqZMAqshgApT+p9eyCPetwXuDHwa3n
WxhHclcZCV7R4rUCgcAkdGSnxcvpIrDPOUNWwxvmAWTStw9ZbTNP8OxCNCm9cyDl
d4bAS1h8D/a+Uk7C70hnu7Sl2w7C7Eu2zhwRUdhhe3+l4GINPK/j99i6NqGPlGpq
xMlMEJ4YS768BqeKFpg0l85PRoEgTsphDeoROSUPsEPdBZ9BxIBlYKTkbKESZDGR
twzYHljx1n1NCDYPflmrb1KpXn4EOcObNghw2KqqNUUWfOeBPwBA1FxzM4BrAStp
DBINpGS4Dc0mjViVegECgcA3hTtm82XdxQXj9LQmb/E3lKx/7H87XIOeNMmvjYuZ
iS9wKrkF+u42vyoDxcKMCnxP5056wpdST4p56r+SBwVTHcc3lGBSGcMTIfwRXrj3
thOA2our2n4ouNIsYyTlcsQSzifwmpRmVMRPxl9fYVdEWUgB83FgHT0D9avvZnF9
t9OccnGJXShAIZIBADhVj/JwG4FbaX42NijD5PNpVLk1Y17OV0I576T9SfaQoBjJ
aH1M/zC4aVaS0DYB/Gxq7v8=
-----END PRIVATE KEY-----


@ -0,0 +1,374 @@
import asyncio
import binascii
import random
import socket
from unittest import TestCase, skipIf
from unittest.mock import patch
from cryptography.hazmat.primitives import serialization
from aioquic.asyncio.client import connect
from aioquic.asyncio.protocol import QuicConnectionProtocol
from aioquic.asyncio.server import serve
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.logger import QuicLogger
from .utils import (
SERVER_CACERTFILE,
SERVER_CERTFILE,
SERVER_KEYFILE,
SKIP_TESTS,
generate_ec_certificate,
run,
)
real_sendto = socket.socket.sendto
def sendto_with_loss(self, data, addr=None):
"""
Simulate 25% packet loss.
"""
if random.random() > 0.25:
real_sendto(self, data, addr)
class SessionTicketStore:
def __init__(self):
self.tickets = {}
def add(self, ticket):
self.tickets[ticket.ticket] = ticket
def pop(self, label):
return self.tickets.pop(label, None)
def handle_stream(reader, writer):
async def serve():
data = await reader.read()
writer.write(bytes(reversed(data)))
writer.write_eof()
asyncio.ensure_future(serve())
class HighLevelTest(TestCase):
def setUp(self):
self.server = None
self.server_host = "localhost"
self.server_port = 4433
def tearDown(self):
if self.server is not None:
self.server.close()
async def run_client(
self,
host=None,
port=None,
cadata=None,
cafile=SERVER_CACERTFILE,
configuration=None,
request=b"ping",
**kwargs
):
if host is None:
host = self.server_host
if port is None:
port = self.server_port
if configuration is None:
configuration = QuicConfiguration(is_client=True)
configuration.load_verify_locations(cadata=cadata, cafile=cafile)
async with connect(host, port, configuration=configuration, **kwargs) as client:
# calling wait_connected() after the connection is established returns immediately
await client.wait_connected()
reader, writer = await client.create_stream()
self.assertEqual(writer.can_write_eof(), True)
self.assertEqual(writer.get_extra_info("stream_id"), 0)
writer.write(request)
writer.write_eof()
response = await reader.read()
# calling wait_closed() after the connection is closed returns immediately
await client.wait_closed()
return response
async def run_server(self, configuration=None, host="::", **kwargs):
if configuration is None:
configuration = QuicConfiguration(is_client=False)
configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)
self.server = await serve(
host=host,
port=self.server_port,
configuration=configuration,
stream_handler=handle_stream,
**kwargs
)
return self.server
def test_connect_and_serve(self):
run(self.run_server())
response = run(self.run_client())
self.assertEqual(response, b"gnip")
def test_connect_and_serve_ipv4(self):
run(self.run_server(host="0.0.0.0"))
response = run(self.run_client(host="127.0.0.1"))
self.assertEqual(response, b"gnip")
@skipIf("ipv6" in SKIP_TESTS, "Skipping IPv6 tests")
def test_connect_and_serve_ipv6(self):
run(self.run_server(host="::"))
response = run(self.run_client(host="::1"))
self.assertEqual(response, b"gnip")
def test_connect_and_serve_ec_certificate(self):
certificate, private_key = generate_ec_certificate(common_name="localhost")
run(
self.run_server(
configuration=QuicConfiguration(
certificate=certificate, private_key=private_key, is_client=False,
)
)
)
response = run(
self.run_client(
cadata=certificate.public_bytes(serialization.Encoding.PEM),
cafile=None,
)
)
self.assertEqual(response, b"gnip")
def test_connect_and_serve_large(self):
"""
Transfer enough data to require raising MAX_DATA and MAX_STREAM_DATA.
"""
data = b"Z" * 2097152
run(self.run_server())
response = run(self.run_client(request=data))
self.assertEqual(response, data)
def test_connect_and_serve_without_client_configuration(self):
async def run_client_without_config():
async with connect(self.server_host, self.server_port) as client:
await client.ping()
run(self.run_server())
with self.assertRaises(ConnectionError):
run(run_client_without_config())
def test_connect_and_serve_writelines(self):
async def run_client_writelines():
configuration = QuicConfiguration(is_client=True)
configuration.load_verify_locations(cafile=SERVER_CACERTFILE)
async with connect(
self.server_host, self.server_port, configuration=configuration
) as client:
reader, writer = await client.create_stream()
assert writer.can_write_eof() is True
writer.writelines([b"01234567", b"89012345"])
writer.write_eof()
return await reader.read()
run(self.run_server())
response = run(run_client_writelines())
self.assertEqual(response, b"5432109876543210")
@skipIf("loss" in SKIP_TESTS, "Skipping loss tests")
@patch("socket.socket.sendto", new_callable=lambda: sendto_with_loss)
def test_connect_and_serve_with_packet_loss(self, mock_sendto):
"""
This test ensures handshake success and stream data is successfully sent
and received in the presence of packet loss (randomized 25% in each direction).
"""
data = b"Z" * 65536
server_configuration = QuicConfiguration(
is_client=False, quic_logger=QuicLogger()
)
server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)
run(self.run_server(configuration=server_configuration))
response = run(
self.run_client(
configuration=QuicConfiguration(
is_client=True, quic_logger=QuicLogger()
),
request=data,
)
)
self.assertEqual(response, data)
def test_connect_and_serve_with_session_ticket(self):
# start server
client_ticket = None
store = SessionTicketStore()
def save_ticket(t):
nonlocal client_ticket
client_ticket = t
run(
self.run_server(
session_ticket_fetcher=store.pop, session_ticket_handler=store.add
)
)
# first request
response = run(self.run_client(session_ticket_handler=save_ticket),)
self.assertEqual(response, b"gnip")
self.assertIsNotNone(client_ticket)
# second request
response = run(
self.run_client(
configuration=QuicConfiguration(
is_client=True, session_ticket=client_ticket
),
)
)
self.assertEqual(response, b"gnip")
def test_connect_and_serve_with_stateless_retry(self):
run(self.run_server())
response = run(self.run_client())
self.assertEqual(response, b"gnip")
def test_connect_and_serve_with_stateless_retry_bad_original_connection_id(self):
"""
If the server's transport parameters do not have the correct
original_connection_id, the connection must fail.
"""
def create_protocol(*args, **kwargs):
protocol = QuicConnectionProtocol(*args, **kwargs)
protocol._quic._original_connection_id = None
return protocol
run(self.run_server(create_protocol=create_protocol, stateless_retry=True))
with self.assertRaises(ConnectionError):
run(self.run_client())
@patch("aioquic.quic.retry.QuicRetryTokenHandler.validate_token")
def test_connect_and_serve_with_stateless_retry_bad(self, mock_validate):
mock_validate.side_effect = ValueError("Decryption failed.")
run(self.run_server(stateless_retry=True))
with self.assertRaises(ConnectionError):
run(
self.run_client(
configuration=QuicConfiguration(is_client=True, idle_timeout=4.0),
)
)
def test_connect_and_serve_with_version_negotiation(self):
run(self.run_server())
# force version negotiation
configuration = QuicConfiguration(is_client=True, quic_logger=QuicLogger())
configuration.supported_versions.insert(0, 0x1A2A3A4A)
response = run(self.run_client(configuration=configuration))
self.assertEqual(response, b"gnip")
def test_connect_timeout(self):
with self.assertRaises(ConnectionError):
run(
self.run_client(
port=4400,
configuration=QuicConfiguration(is_client=True, idle_timeout=5),
)
)
def test_connect_timeout_no_wait_connected(self):
async def run_client_no_wait_connected(configuration):
configuration.load_verify_locations(cafile=SERVER_CACERTFILE)
async with connect(
self.server_host,
4400,
configuration=configuration,
wait_connected=False,
) as client:
await client.ping()
with self.assertRaises(ConnectionError):
run(
run_client_no_wait_connected(
configuration=QuicConfiguration(is_client=True, idle_timeout=5),
)
)
def test_connect_local_port(self):
run(self.run_server())
response = run(self.run_client(local_port=3456))
self.assertEqual(response, b"gnip")
def test_change_connection_id(self):
async def run_client_change_connection_id():
configuration = QuicConfiguration(is_client=True)
configuration.load_verify_locations(cafile=SERVER_CACERTFILE)
async with connect(
self.server_host, self.server_port, configuration=configuration
) as client:
await client.ping()
client.change_connection_id()
await client.ping()
run(self.run_server())
run(run_client_change_connection_id())
def test_key_update(self):
async def run_client_key_update():
configuration = QuicConfiguration(is_client=True)
configuration.load_verify_locations(cafile=SERVER_CACERTFILE)
async with connect(
self.server_host, self.server_port, configuration=configuration
) as client:
await client.ping()
client.request_key_update()
await client.ping()
run(self.run_server())
run(run_client_key_update())
def test_ping(self):
async def run_client_ping():
configuration = QuicConfiguration(is_client=True)
configuration.load_verify_locations(cafile=SERVER_CACERTFILE)
async with connect(
self.server_host, self.server_port, configuration=configuration
) as client:
await client.ping()
await client.ping()
run(self.run_server())
run(run_client_ping())
def test_ping_parallel(self):
async def run_client_ping():
configuration = QuicConfiguration(is_client=True)
configuration.load_verify_locations(cafile=SERVER_CACERTFILE)
async with connect(
self.server_host, self.server_port, configuration=configuration
) as client:
coros = [client.ping() for x in range(16)]
await asyncio.gather(*coros)
run(self.run_server())
run(run_client_ping())
def test_server_receives_garbage(self):
server = run(self.run_server())
server.datagram_received(binascii.unhexlify("c00000000080"), ("1.2.3.4", 1234))
server.close()
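All of the client helpers exercised above follow the same pattern; pulled out of the test harness, a minimal aioquic client looks roughly like the sketch below (host, port and CA file are placeholder values, not anything these tests mandate):

import asyncio
from aioquic.asyncio import connect
from aioquic.quic.configuration import QuicConfiguration

async def minimal_client(host: str, port: int, cafile: str) -> None:
    configuration = QuicConfiguration(is_client=True)
    configuration.load_verify_locations(cafile=cafile)  # trust the test CA, as run_client does
    async with connect(host, port, configuration=configuration) as client:
        await client.ping()  # round-trips a PING frame once the handshake has completed

# asyncio.run(minimal_client("localhost", 4433, "pycacert.pem"))  # placeholder values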


@ -0,0 +1,38 @@
import asyncio
from unittest import TestCase
from aioquic.asyncio.compat import _asynccontextmanager
from .utils import run
@_asynccontextmanager
async def some_context():
await asyncio.sleep(0)
yield
await asyncio.sleep(0)
class AsyncioCompatTest(TestCase):
def test_ok(self):
async def test():
async with some_context():
pass
run(test())
def test_raise_exception(self):
async def test():
async with some_context():
raise RuntimeError("some reason")
with self.assertRaises(RuntimeError):
run(test())
def test_raise_exception_type(self):
async def test():
async with some_context():
raise RuntimeError
with self.assertRaises(RuntimeError):
run(test())


@ -0,0 +1,201 @@
from unittest import TestCase
from aioquic.buffer import Buffer, BufferReadError, BufferWriteError, size_uint_var
class BufferTest(TestCase):
def test_data_slice(self):
buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")
self.assertEqual(buf.data_slice(0, 8), b"\x08\x07\x06\x05\x04\x03\x02\x01")
self.assertEqual(buf.data_slice(1, 3), b"\x07\x06")
with self.assertRaises(BufferReadError):
buf.data_slice(-1, 3)
with self.assertRaises(BufferReadError):
buf.data_slice(0, 9)
with self.assertRaises(BufferReadError):
buf.data_slice(1, 0)
def test_pull_bytes(self):
buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")
self.assertEqual(buf.pull_bytes(3), b"\x08\x07\x06")
def test_pull_bytes_negative(self):
buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")
with self.assertRaises(BufferReadError):
buf.pull_bytes(-1)
def test_pull_bytes_truncated(self):
buf = Buffer(capacity=0)
with self.assertRaises(BufferReadError):
buf.pull_bytes(2)
self.assertEqual(buf.tell(), 0)
def test_pull_bytes_zero(self):
buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")
self.assertEqual(buf.pull_bytes(0), b"")
def test_pull_uint8(self):
buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")
self.assertEqual(buf.pull_uint8(), 0x08)
self.assertEqual(buf.tell(), 1)
def test_pull_uint8_truncated(self):
buf = Buffer(capacity=0)
with self.assertRaises(BufferReadError):
buf.pull_uint8()
self.assertEqual(buf.tell(), 0)
def test_pull_uint16(self):
buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")
self.assertEqual(buf.pull_uint16(), 0x0807)
self.assertEqual(buf.tell(), 2)
def test_pull_uint16_truncated(self):
buf = Buffer(capacity=1)
with self.assertRaises(BufferReadError):
buf.pull_uint16()
self.assertEqual(buf.tell(), 0)
def test_pull_uint32(self):
buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")
self.assertEqual(buf.pull_uint32(), 0x08070605)
self.assertEqual(buf.tell(), 4)
def test_pull_uint32_truncated(self):
buf = Buffer(capacity=3)
with self.assertRaises(BufferReadError):
buf.pull_uint32()
self.assertEqual(buf.tell(), 0)
def test_pull_uint64(self):
buf = Buffer(data=b"\x08\x07\x06\x05\x04\x03\x02\x01")
self.assertEqual(buf.pull_uint64(), 0x0807060504030201)
self.assertEqual(buf.tell(), 8)
def test_pull_uint64_truncated(self):
buf = Buffer(capacity=7)
with self.assertRaises(BufferReadError):
buf.pull_uint64()
self.assertEqual(buf.tell(), 0)
def test_push_bytes(self):
buf = Buffer(capacity=3)
buf.push_bytes(b"\x08\x07\x06")
self.assertEqual(buf.data, b"\x08\x07\x06")
self.assertEqual(buf.tell(), 3)
def test_push_bytes_truncated(self):
buf = Buffer(capacity=3)
with self.assertRaises(BufferWriteError):
buf.push_bytes(b"\x08\x07\x06\x05")
self.assertEqual(buf.tell(), 0)
def test_push_bytes_zero(self):
buf = Buffer(capacity=3)
buf.push_bytes(b"")
self.assertEqual(buf.data, b"")
self.assertEqual(buf.tell(), 0)
def test_push_uint8(self):
buf = Buffer(capacity=1)
buf.push_uint8(0x08)
self.assertEqual(buf.data, b"\x08")
self.assertEqual(buf.tell(), 1)
def test_push_uint16(self):
buf = Buffer(capacity=2)
buf.push_uint16(0x0807)
self.assertEqual(buf.data, b"\x08\x07")
self.assertEqual(buf.tell(), 2)
def test_push_uint32(self):
buf = Buffer(capacity=4)
buf.push_uint32(0x08070605)
self.assertEqual(buf.data, b"\x08\x07\x06\x05")
self.assertEqual(buf.tell(), 4)
def test_push_uint64(self):
buf = Buffer(capacity=8)
buf.push_uint64(0x0807060504030201)
self.assertEqual(buf.data, b"\x08\x07\x06\x05\x04\x03\x02\x01")
self.assertEqual(buf.tell(), 8)
def test_seek(self):
buf = Buffer(data=b"01234567")
self.assertFalse(buf.eof())
self.assertEqual(buf.tell(), 0)
buf.seek(4)
self.assertFalse(buf.eof())
self.assertEqual(buf.tell(), 4)
buf.seek(8)
self.assertTrue(buf.eof())
self.assertEqual(buf.tell(), 8)
with self.assertRaises(BufferReadError):
buf.seek(-1)
self.assertEqual(buf.tell(), 8)
with self.assertRaises(BufferReadError):
buf.seek(9)
self.assertEqual(buf.tell(), 8)
class UintVarTest(TestCase):
def roundtrip(self, data, value):
buf = Buffer(data=data)
self.assertEqual(buf.pull_uint_var(), value)
self.assertEqual(buf.tell(), len(data))
buf = Buffer(capacity=8)
buf.push_uint_var(value)
self.assertEqual(buf.data, data)
def test_uint_var(self):
# 1 byte
self.roundtrip(b"\x00", 0)
self.roundtrip(b"\x01", 1)
self.roundtrip(b"\x25", 37)
self.roundtrip(b"\x3f", 63)
# 2 bytes
self.roundtrip(b"\x7b\xbd", 15293)
self.roundtrip(b"\x7f\xff", 16383)
# 4 bytes
self.roundtrip(b"\x9d\x7f\x3e\x7d", 494878333)
self.roundtrip(b"\xbf\xff\xff\xff", 1073741823)
# 8 bytes
self.roundtrip(b"\xc2\x19\x7c\x5e\xff\x14\xe8\x8c", 151288809941952652)
self.roundtrip(b"\xff\xff\xff\xff\xff\xff\xff\xff", 4611686018427387903)
def test_pull_uint_var_truncated(self):
buf = Buffer(capacity=0)
with self.assertRaises(BufferReadError):
buf.pull_uint_var()
buf = Buffer(data=b"\xff")
with self.assertRaises(BufferReadError):
buf.pull_uint_var()
def test_push_uint_var_too_big(self):
buf = Buffer(capacity=8)
with self.assertRaises(ValueError) as cm:
buf.push_uint_var(4611686018427387904)
self.assertEqual(
str(cm.exception), "Integer is too big for a variable-length integer"
)
def test_size_uint_var(self):
self.assertEqual(size_uint_var(63), 1)
self.assertEqual(size_uint_var(16383), 2)
self.assertEqual(size_uint_var(1073741823), 4)
self.assertEqual(size_uint_var(4611686018427387903), 8)
with self.assertRaises(ValueError) as cm:
size_uint_var(4611686018427387904)
self.assertEqual(
str(cm.exception), "Integer is too big for a variable-length integer"
)
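The UintVarTest vectors above follow QUIC's variable-length integer encoding (draft-ietf-quic-transport, later RFC 9000 §16): the top two bits of the first byte select a 1-, 2-, 4- or 8-byte encoding and the remaining bits carry the value. A minimal standalone sketch of that scheme, for reference only (it is not aioquic's Buffer implementation):

def encode_uint_var(value: int) -> bytes:
    # prefix bits 00/01/10/11 select 6-, 14-, 30- or 62-bit values
    for prefix, length in ((0x00, 1), (0x40, 2), (0x80, 4), (0xC0, 8)):
        if value < (1 << (8 * length - 2)):
            data = bytearray(value.to_bytes(length, "big"))
            data[0] |= prefix
            return bytes(data)
    raise ValueError("Integer is too big for a variable-length integer")

def decode_uint_var(data: bytes) -> int:
    length = 1 << (data[0] >> 6)  # 1, 2, 4 or 8 bytes
    return int.from_bytes(bytes([data[0] & 0x3F]) + data[1:length], "big")

assert encode_uint_var(15293) == b"\x7b\xbd"
assert decode_uint_var(b"\xc2\x19\x7c\x5e\xff\x14\xe8\x8c") == 151288809941952652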

The diff for this file is not shown because of its large size.


@ -0,0 +1,319 @@
import binascii
from unittest import TestCase, skipIf
from aioquic.buffer import Buffer
from aioquic.quic.crypto import (
INITIAL_CIPHER_SUITE,
CryptoError,
CryptoPair,
derive_key_iv_hp,
)
from aioquic.quic.packet import PACKET_FIXED_BIT, QuicProtocolVersion
from aioquic.tls import CipherSuite
from .utils import SKIP_TESTS
PROTOCOL_VERSION = QuicProtocolVersion.DRAFT_25
CHACHA20_CLIENT_PACKET_NUMBER = 2
CHACHA20_CLIENT_PLAIN_HEADER = binascii.unhexlify(
"e1ff0000160880b57c7b70d8524b0850fc2a28e240fd7640170002"
)
CHACHA20_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify("0201000000")
CHACHA20_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify(
"e8ff0000160880b57c7b70d8524b0850fc2a28e240fd7640178313b04be98449"
"eb10567e25ce930381f2a5b7da2db8db"
)
LONG_CLIENT_PACKET_NUMBER = 2
LONG_CLIENT_PLAIN_HEADER = binascii.unhexlify(
"c3ff000017088394c8f03e5157080000449e00000002"
)
LONG_CLIENT_PLAIN_PAYLOAD = binascii.unhexlify(
"060040c4010000c003036660261ff947cea49cce6cfad687f457cf1b14531ba1"
"4131a0e8f309a1d0b9c4000006130113031302010000910000000b0009000006"
"736572766572ff01000100000a00140012001d00170018001901000101010201"
"03010400230000003300260024001d00204cfdfcd178b784bf328cae793b136f"
"2aedce005ff183d7bb1495207236647037002b0003020304000d0020001e0403"
"05030603020308040805080604010501060102010402050206020202002d0002"
"0101001c00024001"
) + bytes(962)
LONG_CLIENT_ENCRYPTED_PACKET = binascii.unhexlify(
"c0ff000017088394c8f03e5157080000449e3b343aa8535064a4268a0d9d7b1c"
"9d250ae355162276e9b1e3011ef6bbc0ab48ad5bcc2681e953857ca62becd752"
"4daac473e68d7405fbba4e9ee616c87038bdbe908c06d9605d9ac49030359eec"
"b1d05a14e117db8cede2bb09d0dbbfee271cb374d8f10abec82d0f59a1dee29f"
"e95638ed8dd41da07487468791b719c55c46968eb3b54680037102a28e53dc1d"
"12903db0af5821794b41c4a93357fa59ce69cfe7f6bdfa629eef78616447e1d6"
"11c4baf71bf33febcb03137c2c75d25317d3e13b684370f668411c0f00304b50"
"1c8fd422bd9b9ad81d643b20da89ca0525d24d2b142041cae0af205092e43008"
"0cd8559ea4c5c6e4fa3f66082b7d303e52ce0162baa958532b0bbc2bc785681f"
"cf37485dff6595e01e739c8ac9efba31b985d5f656cc092432d781db95221724"
"87641c4d3ab8ece01e39bc85b15436614775a98ba8fa12d46f9b35e2a55eb72d"
"7f85181a366663387ddc20551807e007673bd7e26bf9b29b5ab10a1ca87cbb7a"
"d97e99eb66959c2a9bc3cbde4707ff7720b110fa95354674e395812e47a0ae53"
"b464dcb2d1f345df360dc227270c750676f6724eb479f0d2fbb6124429990457"
"ac6c9167f40aab739998f38b9eccb24fd47c8410131bf65a52af841275d5b3d1"
"880b197df2b5dea3e6de56ebce3ffb6e9277a82082f8d9677a6767089b671ebd"
"244c214f0bde95c2beb02cd1172d58bdf39dce56ff68eb35ab39b49b4eac7c81"
"5ea60451d6e6ab82119118df02a586844a9ffe162ba006d0669ef57668cab38b"
"62f71a2523a084852cd1d079b3658dc2f3e87949b550bab3e177cfc49ed190df"
"f0630e43077c30de8f6ae081537f1e83da537da980afa668e7b7fb25301cf741"
"524be3c49884b42821f17552fbd1931a813017b6b6590a41ea18b6ba49cd48a4"
"40bd9a3346a7623fb4ba34a3ee571e3c731f35a7a3cf25b551a680fa68763507"
"b7fde3aaf023c50b9d22da6876ba337eb5e9dd9ec3daf970242b6c5aab3aa4b2"
"96ad8b9f6832f686ef70fa938b31b4e5ddd7364442d3ea72e73d668fb0937796"
"f462923a81a47e1cee7426ff6d9221269b5a62ec03d6ec94d12606cb485560ba"
"b574816009e96504249385bb61a819be04f62c2066214d8360a2022beb316240"
"b6c7d78bbe56c13082e0ca272661210abf020bf3b5783f1426436cf9ff418405"
"93a5d0638d32fc51c5c65ff291a3a7a52fd6775e623a4439cc08dd25582febc9"
"44ef92d8dbd329c91de3e9c9582e41f17f3d186f104ad3f90995116c682a2a14"
"a3b4b1f547c335f0be710fc9fc03e0e587b8cda31ce65b969878a4ad4283e6d5"
"b0373f43da86e9e0ffe1ae0fddd3516255bd74566f36a38703d5f34249ded1f6"
"6b3d9b45b9af2ccfefe984e13376b1b2c6404aa48c8026132343da3f3a33659e"
"c1b3e95080540b28b7f3fcd35fa5d843b579a84c089121a60d8c1754915c344e"
"eaf45a9bf27dc0c1e78416169122091313eb0e87555abd706626e557fc36a04f"
"cd191a58829104d6075c5594f627ca506bf181daec940f4a4f3af0074eee89da"
"acde6758312622d4fa675b39f728e062d2bee680d8f41a597c262648bb18bcfc"
"13c8b3d97b1a77b2ac3af745d61a34cc4709865bac824a94bb19058015e4e42d"
"c9be6c7803567321829dd85853396269"
)
LONG_SERVER_PACKET_NUMBER = 1
LONG_SERVER_PLAIN_HEADER = binascii.unhexlify(
"c1ff0000170008f067a5502a4262b50040740001"
)
LONG_SERVER_PLAIN_PAYLOAD = binascii.unhexlify(
"0d0000000018410a020000560303eefce7f7b37ba1d1632e96677825ddf73988"
"cfc79825df566dc5430b9a045a1200130100002e00330024001d00209d3c940d"
"89690b84d08a60993c144eca684d1081287c834d5311bcf32bb9da1a002b0002"
"0304"
)
LONG_SERVER_ENCRYPTED_PACKET = binascii.unhexlify(
"c9ff0000170008f067a5502a4262b5004074168bf22b7002596f99ae67abf65a"
"5852f54f58c37c808682e2e40492d8a3899fb04fc0afe9aabc8767b18a0aa493"
"537426373b48d502214dd856d63b78cee37bc664b3fe86d487ac7a77c53038a3"
"cd32f0b5004d9f5754c4f7f2d1f35cf3f7116351c92b9cf9bb6d091ddfc8b32d"
"432348a2c413"
)
SHORT_SERVER_PACKET_NUMBER = 3
SHORT_SERVER_PLAIN_HEADER = binascii.unhexlify("41b01fd24a586a9cf30003")
SHORT_SERVER_PLAIN_PAYLOAD = binascii.unhexlify(
"06003904000035000151805a4bebf5000020b098c8dc4183e4c182572e10ac3e"
"2b88897e0524c8461847548bd2dffa2c0ae60008002a0004ffffffff"
)
SHORT_SERVER_ENCRYPTED_PACKET = binascii.unhexlify(
"5db01fd24a586a9cf33dec094aaec6d6b4b7a5e15f5a3f05d06cf1ad0355c19d"
"cce0807eecf7bf1c844a66e1ecd1f74b2a2d69bfd25d217833edd973246597bd"
"5107ea15cb1e210045396afa602fe23432f4ab24ce251b"
)
class CryptoTest(TestCase):
"""
Test vectors from:
https://tools.ietf.org/html/draft-ietf-quic-tls-18#appendix-A
"""
def create_crypto(self, is_client):
pair = CryptoPair()
pair.setup_initial(
cid=binascii.unhexlify("8394c8f03e515708"),
is_client=is_client,
version=PROTOCOL_VERSION,
)
return pair
def test_derive_key_iv_hp(self):
# client
secret = binascii.unhexlify(
"8a3515a14ae3c31b9c2d6d5bc58538ca5cd2baa119087143e60887428dcb52f6"
)
key, iv, hp = derive_key_iv_hp(INITIAL_CIPHER_SUITE, secret)
self.assertEqual(key, binascii.unhexlify("98b0d7e5e7a402c67c33f350fa65ea54"))
self.assertEqual(iv, binascii.unhexlify("19e94387805eb0b46c03a788"))
self.assertEqual(hp, binascii.unhexlify("0edd982a6ac527f2eddcbb7348dea5d7"))
# server
secret = binascii.unhexlify(
"47b2eaea6c266e32c0697a9e2a898bdf5c4fb3e5ac34f0e549bf2c58581a3811"
)
key, iv, hp = derive_key_iv_hp(INITIAL_CIPHER_SUITE, secret)
self.assertEqual(key, binascii.unhexlify("9a8be902a9bdd91d16064ca118045fb4"))
self.assertEqual(iv, binascii.unhexlify("0a82086d32205ba22241d8dc"))
self.assertEqual(hp, binascii.unhexlify("94b9452d2b3c7c7f6da7fdd8593537fd"))
@skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests")
def test_decrypt_chacha20(self):
pair = CryptoPair()
pair.recv.setup(
cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256,
secret=binascii.unhexlify(
"b42772df33c9719a32820d302aa664d080d7f5ea7a71a330f87864cb289ae8c0"
),
version=PROTOCOL_VERSION,
)
plain_header, plain_payload, packet_number = pair.decrypt_packet(
CHACHA20_CLIENT_ENCRYPTED_PACKET, 25, 0
)
self.assertEqual(plain_header, CHACHA20_CLIENT_PLAIN_HEADER)
self.assertEqual(plain_payload, CHACHA20_CLIENT_PLAIN_PAYLOAD)
self.assertEqual(packet_number, CHACHA20_CLIENT_PACKET_NUMBER)
def test_decrypt_long_client(self):
pair = self.create_crypto(is_client=False)
plain_header, plain_payload, packet_number = pair.decrypt_packet(
LONG_CLIENT_ENCRYPTED_PACKET, 18, 0
)
self.assertEqual(plain_header, LONG_CLIENT_PLAIN_HEADER)
self.assertEqual(plain_payload, LONG_CLIENT_PLAIN_PAYLOAD)
self.assertEqual(packet_number, LONG_CLIENT_PACKET_NUMBER)
def test_decrypt_long_server(self):
pair = self.create_crypto(is_client=True)
plain_header, plain_payload, packet_number = pair.decrypt_packet(
LONG_SERVER_ENCRYPTED_PACKET, 18, 0
)
self.assertEqual(plain_header, LONG_SERVER_PLAIN_HEADER)
self.assertEqual(plain_payload, LONG_SERVER_PLAIN_PAYLOAD)
self.assertEqual(packet_number, LONG_SERVER_PACKET_NUMBER)
def test_decrypt_no_key(self):
pair = CryptoPair()
with self.assertRaises(CryptoError):
pair.decrypt_packet(LONG_SERVER_ENCRYPTED_PACKET, 18, 0)
def test_decrypt_short_server(self):
pair = CryptoPair()
pair.recv.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=binascii.unhexlify(
"310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100"
),
version=PROTOCOL_VERSION,
)
plain_header, plain_payload, packet_number = pair.decrypt_packet(
SHORT_SERVER_ENCRYPTED_PACKET, 9, 0
)
self.assertEqual(plain_header, SHORT_SERVER_PLAIN_HEADER)
self.assertEqual(plain_payload, SHORT_SERVER_PLAIN_PAYLOAD)
self.assertEqual(packet_number, SHORT_SERVER_PACKET_NUMBER)
@skipIf("chacha20" in SKIP_TESTS, "Skipping chacha20 tests")
def test_encrypt_chacha20(self):
pair = CryptoPair()
pair.send.setup(
cipher_suite=CipherSuite.CHACHA20_POLY1305_SHA256,
secret=binascii.unhexlify(
"b42772df33c9719a32820d302aa664d080d7f5ea7a71a330f87864cb289ae8c0"
),
version=PROTOCOL_VERSION,
)
packet = pair.encrypt_packet(
CHACHA20_CLIENT_PLAIN_HEADER,
CHACHA20_CLIENT_PLAIN_PAYLOAD,
CHACHA20_CLIENT_PACKET_NUMBER,
)
self.assertEqual(packet, CHACHA20_CLIENT_ENCRYPTED_PACKET)
def test_encrypt_long_client(self):
pair = self.create_crypto(is_client=True)
packet = pair.encrypt_packet(
LONG_CLIENT_PLAIN_HEADER,
LONG_CLIENT_PLAIN_PAYLOAD,
LONG_CLIENT_PACKET_NUMBER,
)
self.assertEqual(packet, LONG_CLIENT_ENCRYPTED_PACKET)
def test_encrypt_long_server(self):
pair = self.create_crypto(is_client=False)
packet = pair.encrypt_packet(
LONG_SERVER_PLAIN_HEADER,
LONG_SERVER_PLAIN_PAYLOAD,
LONG_SERVER_PACKET_NUMBER,
)
self.assertEqual(packet, LONG_SERVER_ENCRYPTED_PACKET)
def test_encrypt_short_server(self):
pair = CryptoPair()
pair.send.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=binascii.unhexlify(
"310281977cb8c1c1c1212d784b2d29e5a6489e23de848d370a5a2f9537f3a100"
),
version=PROTOCOL_VERSION,
)
packet = pair.encrypt_packet(
SHORT_SERVER_PLAIN_HEADER,
SHORT_SERVER_PLAIN_PAYLOAD,
SHORT_SERVER_PACKET_NUMBER,
)
self.assertEqual(packet, SHORT_SERVER_ENCRYPTED_PACKET)
def test_key_update(self):
pair1 = self.create_crypto(is_client=True)
pair2 = self.create_crypto(is_client=False)
def create_packet(key_phase, packet_number):
buf = Buffer(capacity=100)
buf.push_uint8(PACKET_FIXED_BIT | key_phase << 2 | 1)
buf.push_bytes(binascii.unhexlify("8394c8f03e515708"))
buf.push_uint16(packet_number)
return buf.data, b"\x00\x01\x02\x03"
def send(sender, receiver, packet_number=0):
plain_header, plain_payload = create_packet(
key_phase=sender.key_phase, packet_number=packet_number
)
encrypted = sender.encrypt_packet(
plain_header, plain_payload, packet_number
)
recov_header, recov_payload, recov_packet_number = receiver.decrypt_packet(
encrypted, len(plain_header) - 2, 0
)
self.assertEqual(recov_header, plain_header)
self.assertEqual(recov_payload, plain_payload)
self.assertEqual(recov_packet_number, packet_number)
# roundtrip
send(pair1, pair2, 0)
send(pair2, pair1, 0)
self.assertEqual(pair1.key_phase, 0)
self.assertEqual(pair2.key_phase, 0)
# pair 1 key update
pair1.update_key()
# roundtrip
send(pair1, pair2, 1)
send(pair2, pair1, 1)
self.assertEqual(pair1.key_phase, 1)
self.assertEqual(pair2.key_phase, 1)
# pair 2 key update
pair2.update_key()
# roundtrip
send(pair2, pair1, 2)
send(pair1, pair2, 2)
self.assertEqual(pair1.key_phase, 0)
self.assertEqual(pair2.key_phase, 0)
# pair 1 key update, but not next to send
pair1.update_key()
# roundtrip
send(pair2, pair1, 3)
send(pair1, pair2, 3)
self.assertEqual(pair1.key_phase, 1)
self.assertEqual(pair2.key_phase, 1)
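The key/iv/hp triples asserted in test_derive_key_iv_hp come from the TLS 1.3 HKDF-Expand-Label construction applied to the given secret with the labels "quic key" (16 bytes for AES-128-GCM), "quic iv" (12 bytes) and "quic hp" (16 bytes). A rough standalone sketch of that derivation, assuming the cryptography package and outputs no longer than one SHA-256 block; it should reproduce the client values asserted above:

import binascii
from cryptography.hazmat.primitives import hashes, hmac

def hkdf_expand_label(secret: bytes, label: bytes, length: int) -> bytes:
    # TLS 1.3 HkdfLabel with an empty context; one HMAC block suffices for length <= 32
    full_label = b"tls13 " + label
    info = length.to_bytes(2, "big") + bytes([len(full_label)]) + full_label + b"\x00"
    mac = hmac.HMAC(secret, hashes.SHA256())
    mac.update(info + b"\x01")
    return mac.finalize()[:length]

secret = binascii.unhexlify(
    "8a3515a14ae3c31b9c2d6d5bc58538ca5cd2baa119087143e60887428dcb52f6"
)
key = hkdf_expand_label(secret, b"quic key", 16)  # expected: 98b0d7e5e7a402c67c33f350fa65ea54
iv = hkdf_expand_label(secret, b"quic iv", 12)    # expected: 19e94387805eb0b46c03a788
hp = hkdf_expand_label(secret, b"quic hp", 16)    # expected: 0edd982a6ac527f2eddcbb7348dea5d7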


@ -0,0 +1,148 @@
from unittest import TestCase
from aioquic.h0.connection import H0_ALPN, H0Connection
from aioquic.h3.events import DataReceived, HeadersReceived
from .test_connection import client_and_server, transfer
def h0_client_and_server():
return client_and_server(
client_options={"alpn_protocols": H0_ALPN},
server_options={"alpn_protocols": H0_ALPN},
)
def h0_transfer(quic_sender, h0_receiver):
quic_receiver = h0_receiver._quic
transfer(quic_sender, quic_receiver)
# process QUIC events
http_events = []
event = quic_receiver.next_event()
while event is not None:
http_events.extend(h0_receiver.handle_event(event))
event = quic_receiver.next_event()
return http_events
class H0ConnectionTest(TestCase):
def test_connect(self):
with h0_client_and_server() as (quic_client, quic_server):
h0_client = H0Connection(quic_client)
h0_server = H0Connection(quic_server)
# send request
stream_id = quic_client.get_next_available_stream_id()
h0_client.send_headers(
stream_id=stream_id,
headers=[
(b":method", b"GET"),
(b":scheme", b"https"),
(b":authority", b"localhost"),
(b":path", b"/"),
],
)
h0_client.send_data(stream_id=stream_id, data=b"", end_stream=True)
# receive request
events = h0_transfer(quic_client, h0_server)
self.assertEqual(len(events), 2)
self.assertTrue(isinstance(events[0], HeadersReceived))
self.assertEqual(
events[0].headers, [(b":method", b"GET"), (b":path", b"/")]
)
self.assertEqual(events[0].stream_id, stream_id)
self.assertEqual(events[0].stream_ended, False)
self.assertTrue(isinstance(events[1], DataReceived))
self.assertEqual(events[1].data, b"")
self.assertEqual(events[1].stream_id, stream_id)
self.assertEqual(events[1].stream_ended, True)
# send response
h0_server.send_headers(
stream_id=stream_id,
headers=[
(b":status", b"200"),
(b"content-type", b"text/html; charset=utf-8"),
],
)
h0_server.send_data(
stream_id=stream_id,
data=b"<html><body>hello</body></html>",
end_stream=True,
)
# receive response
events = h0_transfer(quic_server, h0_client)
self.assertEqual(len(events), 2)
self.assertTrue(isinstance(events[0], HeadersReceived))
self.assertEqual(events[0].headers, [])
self.assertEqual(events[0].stream_id, stream_id)
self.assertEqual(events[0].stream_ended, False)
self.assertTrue(isinstance(events[1], DataReceived))
self.assertEqual(events[1].data, b"<html><body>hello</body></html>")
self.assertEqual(events[1].stream_id, stream_id)
self.assertEqual(events[1].stream_ended, True)
def test_headers_only(self):
with h0_client_and_server() as (quic_client, quic_server):
h0_client = H0Connection(quic_client)
h0_server = H0Connection(quic_server)
# send request
stream_id = quic_client.get_next_available_stream_id()
h0_client.send_headers(
stream_id=stream_id,
headers=[
(b":method", b"HEAD"),
(b":scheme", b"https"),
(b":authority", b"localhost"),
(b":path", b"/"),
],
end_stream=True,
)
# receive request
events = h0_transfer(quic_client, h0_server)
self.assertEqual(len(events), 2)
self.assertTrue(isinstance(events[0], HeadersReceived))
self.assertEqual(
events[0].headers, [(b":method", b"HEAD"), (b":path", b"/")]
)
self.assertEqual(events[0].stream_id, stream_id)
self.assertEqual(events[0].stream_ended, False)
self.assertTrue(isinstance(events[1], DataReceived))
self.assertEqual(events[1].data, b"")
self.assertEqual(events[1].stream_id, stream_id)
self.assertEqual(events[1].stream_ended, True)
# send response
h0_server.send_headers(
stream_id=stream_id,
headers=[
(b":status", b"200"),
(b"content-type", b"text/html; charset=utf-8"),
],
end_stream=True,
)
# receive response
events = h0_transfer(quic_server, h0_client)
self.assertEqual(len(events), 2)
self.assertTrue(isinstance(events[0], HeadersReceived))
self.assertEqual(events[0].headers, [])
self.assertEqual(events[0].stream_id, stream_id)
self.assertEqual(events[0].stream_ended, False)
self.assertTrue(isinstance(events[1], DataReceived))
self.assertEqual(events[1].data, b"")
self.assertEqual(events[1].stream_id, stream_id)
self.assertEqual(events[1].stream_ended, True)
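A note on the empty response header list asserted above: the H0_ALPN protocols negotiate HTTP/0.9, which has no header block on the wire, so H0Connection synthesizes the request pseudo-headers from the request line and reports no headers for responses. Roughly what the streams carry for the exchange above (an illustration, not necessarily aioquic's exact byte output):

request_line = b"GET /\r\n"                         # method and path only; HTTP/0.9 has no headers
response_body = b"<html><body>hello</body></html>"  # body bytes only, hence headers == [] above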

The diff for this file is not shown because of its large size.


@ -0,0 +1,37 @@
from unittest import TestCase
from aioquic.quic.logger import QuicLogger
class QuicLoggerTest(TestCase):
def test_empty(self):
logger = QuicLogger()
self.assertEqual(logger.to_dict(), {"qlog_version": "draft-01", "traces": []})
def test_empty_trace(self):
logger = QuicLogger()
trace = logger.start_trace(is_client=True, odcid=bytes(8))
logger.end_trace(trace)
self.assertEqual(
logger.to_dict(),
{
"qlog_version": "draft-01",
"traces": [
{
"common_fields": {
"ODCID": "0000000000000000",
"reference_time": "0",
},
"configuration": {"time_units": "us"},
"event_fields": [
"relative_time",
"category",
"event_type",
"data",
],
"events": [],
"vantage_point": {"name": "aioquic", "type": "client"},
}
],
},
)
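QuicLogger.to_dict() returns a plain qlog draft-01 document, so persisting a trace is just JSON serialization. A minimal sketch (the output path is a placeholder):

import json
from aioquic.quic.logger import QuicLogger

logger = QuicLogger()
trace = logger.start_trace(is_client=True, odcid=bytes(8))
# ... hand the logger to a connection via QuicConfiguration(quic_logger=logger) ...
logger.end_trace(trace)
with open("connection.qlog", "w") as fp:  # placeholder path
    json.dump(logger.to_dict(), fp)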


@ -0,0 +1,513 @@
import binascii
from unittest import TestCase
from aioquic.buffer import Buffer, BufferReadError
from aioquic.quic import packet
from aioquic.quic.packet import (
PACKET_TYPE_INITIAL,
PACKET_TYPE_RETRY,
QuicPreferredAddress,
QuicProtocolVersion,
QuicTransportParameters,
decode_packet_number,
encode_quic_version_negotiation,
get_retry_integrity_tag,
pull_quic_header,
pull_quic_preferred_address,
pull_quic_transport_parameters,
push_quic_preferred_address,
push_quic_transport_parameters,
)
from .utils import load
class PacketTest(TestCase):
def test_decode_packet_number(self):
# expected = 0
for i in range(0, 256):
self.assertEqual(decode_packet_number(i, 8, expected=0), i)
# expected = 128
self.assertEqual(decode_packet_number(0, 8, expected=128), 256)
for i in range(1, 256):
self.assertEqual(decode_packet_number(i, 8, expected=128), i)
# expected = 129
self.assertEqual(decode_packet_number(0, 8, expected=129), 256)
self.assertEqual(decode_packet_number(1, 8, expected=129), 257)
for i in range(2, 256):
self.assertEqual(decode_packet_number(i, 8, expected=129), i)
# expected = 256
for i in range(0, 128):
self.assertEqual(decode_packet_number(i, 8, expected=256), 256 + i)
for i in range(129, 256):
self.assertEqual(decode_packet_number(i, 8, expected=256), i)
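# The vectors above exercise the packet-number recovery algorithm from Appendix A of
# draft-ietf-quic-transport (later RFC 9000, A.3). A minimal sketch of that algorithm,
# assuming decode_packet_number mirrors it:
#
#     def decode_pn(truncated, num_bits, expected):
#         window = 1 << num_bits
#         half = window // 2
#         candidate = (expected & ~(window - 1)) | truncated
#         if candidate <= expected - half and candidate + window < (1 << 62):
#             return candidate + window
#         if candidate > expected + half and candidate >= window:
#             return candidate - window
#         return candidate
#
# e.g. decode_pn(0, 8, expected=128) == 256, matching the assertion above.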
def test_pull_empty(self):
buf = Buffer(data=b"")
with self.assertRaises(BufferReadError):
pull_quic_header(buf, host_cid_length=8)
def test_pull_initial_client(self):
buf = Buffer(data=load("initial_client.bin"))
header = pull_quic_header(buf, host_cid_length=8)
self.assertTrue(header.is_long_header)
self.assertEqual(header.version, QuicProtocolVersion.DRAFT_25)
self.assertEqual(header.packet_type, PACKET_TYPE_INITIAL)
self.assertEqual(header.destination_cid, binascii.unhexlify("858b39368b8e3c6e"))
self.assertEqual(header.source_cid, b"")
self.assertEqual(header.token, b"")
self.assertEqual(header.integrity_tag, b"")
self.assertEqual(header.rest_length, 1262)
self.assertEqual(buf.tell(), 18)
def test_pull_initial_server(self):
buf = Buffer(data=load("initial_server.bin"))
header = pull_quic_header(buf, host_cid_length=8)
self.assertTrue(header.is_long_header)
self.assertEqual(header.version, QuicProtocolVersion.DRAFT_25)
self.assertEqual(header.packet_type, PACKET_TYPE_INITIAL)
self.assertEqual(header.destination_cid, b"")
self.assertEqual(header.source_cid, binascii.unhexlify("195c68344e28d479"))
self.assertEqual(header.token, b"")
self.assertEqual(header.integrity_tag, b"")
self.assertEqual(header.rest_length, 184)
self.assertEqual(buf.tell(), 18)
def test_pull_retry(self):
buf = Buffer(data=load("retry.bin"))
header = pull_quic_header(buf, host_cid_length=8)
self.assertTrue(header.is_long_header)
self.assertEqual(header.version, QuicProtocolVersion.DRAFT_25)
self.assertEqual(header.packet_type, PACKET_TYPE_RETRY)
self.assertEqual(header.destination_cid, binascii.unhexlify("e9d146d8d14cb28e"))
self.assertEqual(
header.source_cid,
binascii.unhexlify("0b0a205a648fcf82d85f128b67bbe08053e6"),
)
self.assertEqual(
header.token,
binascii.unhexlify(
"44397a35d698393c134b08a932737859f446d3aadd00ed81540c8d8de172"
"906d3e7a111b503f9729b8928e7528f9a86a4581f9ebb4cb3b53c283661e"
"8530741a99192ee56914c5626998ec0f"
),
)
self.assertEqual(
header.integrity_tag, binascii.unhexlify("e1a3c80c797ea401c08fc9c342a2d90d")
)
self.assertEqual(header.rest_length, 0)
self.assertEqual(buf.tell(), 125)
# check integrity
self.assertEqual(
get_retry_integrity_tag(
buf.data_slice(0, 109), binascii.unhexlify("fbbd219b7363b64b"),
),
header.integrity_tag,
)
def test_pull_version_negotiation(self):
buf = Buffer(data=load("version_negotiation.bin"))
header = pull_quic_header(buf, host_cid_length=8)
self.assertTrue(header.is_long_header)
self.assertEqual(header.version, QuicProtocolVersion.NEGOTIATION)
self.assertEqual(header.packet_type, None)
self.assertEqual(header.destination_cid, binascii.unhexlify("9aac5a49ba87a849"))
self.assertEqual(header.source_cid, binascii.unhexlify("f92f4336fa951ba1"))
self.assertEqual(header.token, b"")
self.assertEqual(header.integrity_tag, b"")
self.assertEqual(header.rest_length, 8)
self.assertEqual(buf.tell(), 23)
def test_pull_long_header_dcid_too_long(self):
buf = Buffer(
data=binascii.unhexlify(
"c6ff0000161500000000000000000000000000000000000000000000004"
"01c514f99ec4bbf1f7a30f9b0c94fef717f1c1d07fec24c99a864da7ede"
)
)
with self.assertRaises(ValueError) as cm:
pull_quic_header(buf, host_cid_length=8)
self.assertEqual(str(cm.exception), "Destination CID is too long (21 bytes)")
def test_pull_long_header_scid_too_long(self):
buf = Buffer(
data=binascii.unhexlify(
"c2ff0000160015000000000000000000000000000000000000000000004"
"01cfcee99ec4bbf1f7a30f9b0c9417b8c263cdd8cc972a4439d68a46320"
)
)
with self.assertRaises(ValueError) as cm:
pull_quic_header(buf, host_cid_length=8)
self.assertEqual(str(cm.exception), "Source CID is too long (21 bytes)")
def test_pull_long_header_no_fixed_bit(self):
buf = Buffer(data=b"\x80\xff\x00\x00\x11\x00\x00")
with self.assertRaises(ValueError) as cm:
pull_quic_header(buf, host_cid_length=8)
self.assertEqual(str(cm.exception), "Packet fixed bit is zero")
def test_pull_long_header_too_short(self):
buf = Buffer(data=b"\xc0\x00")
with self.assertRaises(BufferReadError):
pull_quic_header(buf, host_cid_length=8)
def test_pull_short_header(self):
buf = Buffer(data=load("short_header.bin"))
header = pull_quic_header(buf, host_cid_length=8)
self.assertFalse(header.is_long_header)
self.assertEqual(header.version, None)
self.assertEqual(header.packet_type, 0x50)
self.assertEqual(header.destination_cid, binascii.unhexlify("f45aa7b59c0e1ad6"))
self.assertEqual(header.source_cid, b"")
self.assertEqual(header.token, b"")
self.assertEqual(header.integrity_tag, b"")
self.assertEqual(header.rest_length, 12)
self.assertEqual(buf.tell(), 9)
def test_pull_short_header_no_fixed_bit(self):
buf = Buffer(data=b"\x00")
with self.assertRaises(ValueError) as cm:
pull_quic_header(buf, host_cid_length=8)
self.assertEqual(str(cm.exception), "Packet fixed bit is zero")
def test_encode_quic_version_negotiation(self):
data = encode_quic_version_negotiation(
destination_cid=binascii.unhexlify("9aac5a49ba87a849"),
source_cid=binascii.unhexlify("f92f4336fa951ba1"),
supported_versions=[0x45474716, QuicProtocolVersion.DRAFT_25],
)
self.assertEqual(data[1:], load("version_negotiation.bin")[1:])
class ParamsTest(TestCase):
maxDiff = None
def test_params(self):
data = binascii.unhexlify(
"010267100210cc2fd6e7d97a53ab5be85b28d75c8008030247e404048005fff"
"a05048000ffff06048000ffff0801060a01030b0119"
)
# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(
buf, protocol_version=QuicProtocolVersion.DRAFT_27
)
self.assertEqual(
params,
QuicTransportParameters(
idle_timeout=10000,
stateless_reset_token=b"\xcc/\xd6\xe7\xd9zS\xab[\xe8[(\xd7\\\x80\x08",
max_packet_size=2020,
initial_max_data=393210,
initial_max_stream_data_bidi_local=65535,
initial_max_stream_data_bidi_remote=65535,
initial_max_stream_data_uni=None,
initial_max_streams_bidi=6,
initial_max_streams_uni=None,
ack_delay_exponent=3,
max_ack_delay=25,
),
)
# serialize
buf = Buffer(capacity=len(data))
push_quic_transport_parameters(
buf, params, protocol_version=QuicProtocolVersion.DRAFT_27
)
self.assertEqual(len(buf.data), len(data))
def test_params_legacy(self):
data = binascii.unhexlify(
"004700020010cc2fd6e7d97a53ab5be85b28d75c80080008000106000100026"
"710000600048000ffff000500048000ffff000400048005fffa000a00010300"
"0b0001190003000247e4"
)
# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(
buf, protocol_version=QuicProtocolVersion.DRAFT_25
)
self.assertEqual(
params,
QuicTransportParameters(
idle_timeout=10000,
stateless_reset_token=b"\xcc/\xd6\xe7\xd9zS\xab[\xe8[(\xd7\\\x80\x08",
max_packet_size=2020,
initial_max_data=393210,
initial_max_stream_data_bidi_local=65535,
initial_max_stream_data_bidi_remote=65535,
initial_max_stream_data_uni=None,
initial_max_streams_bidi=6,
initial_max_streams_uni=None,
ack_delay_exponent=3,
max_ack_delay=25,
),
)
# serialize
buf = Buffer(capacity=len(data))
push_quic_transport_parameters(
buf, params, protocol_version=QuicProtocolVersion.DRAFT_25
)
self.assertEqual(len(buf.data), len(data))
def test_params_disable_active_migration(self):
data = binascii.unhexlify("0c00")
# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(
buf, protocol_version=QuicProtocolVersion.DRAFT_27
)
self.assertEqual(params, QuicTransportParameters(disable_active_migration=True))
# serialize
buf = Buffer(capacity=len(data))
push_quic_transport_parameters(
buf, params, protocol_version=QuicProtocolVersion.DRAFT_27
)
self.assertEqual(buf.data, data)
def test_params_disable_active_migration_legacy(self):
data = binascii.unhexlify("0004000c0000")
# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(
buf, protocol_version=QuicProtocolVersion.DRAFT_25
)
self.assertEqual(params, QuicTransportParameters(disable_active_migration=True))
# serialize
buf = Buffer(capacity=len(data))
push_quic_transport_parameters(
buf, params, protocol_version=QuicProtocolVersion.DRAFT_25
)
self.assertEqual(buf.data, data)
def test_params_preferred_address(self):
data = binascii.unhexlify(
"0d3b8ba27b8611532400890200000000f03c91fffe69a45411531262c4518d6"
"3013f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48ecb4"
)
# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(
buf, protocol_version=QuicProtocolVersion.DRAFT_27
)
self.assertEqual(
params,
QuicTransportParameters(
preferred_address=QuicPreferredAddress(
ipv4_address=("139.162.123.134", 4435),
ipv6_address=("2400:8902::f03c:91ff:fe69:a454", 4435),
connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7",
stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4",
),
),
)
# serialize
buf = Buffer(capacity=1000)
push_quic_transport_parameters(
buf, params, protocol_version=QuicProtocolVersion.DRAFT_27
)
self.assertEqual(buf.data, data)
def test_params_preferred_address_legacy(self):
data = binascii.unhexlify(
"003f000d003b8ba27b8611532400890200000000f03c91fffe69a4541153126"
"2c4518d63013f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48"
"ecb4"
)
# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(
buf, protocol_version=QuicProtocolVersion.DRAFT_25
)
self.assertEqual(
params,
QuicTransportParameters(
preferred_address=QuicPreferredAddress(
ipv4_address=("139.162.123.134", 4435),
ipv6_address=("2400:8902::f03c:91ff:fe69:a454", 4435),
connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7",
stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4",
),
),
)
# serialize
buf = Buffer(capacity=len(data))
push_quic_transport_parameters(
buf, params, protocol_version=QuicProtocolVersion.DRAFT_25
)
self.assertEqual(buf.data, data)
def test_params_unknown(self):
data = binascii.unhexlify("8000ff000100")
# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(
buf, protocol_version=QuicProtocolVersion.DRAFT_27
)
self.assertEqual(params, QuicTransportParameters())
def test_params_unknown_legacy(self):
# fb.mvfst.net sends a proprietary parameter 65280
data = binascii.unhexlify(
"006400050004800104000006000480010400000700048001040000040004801"
"0000000080008c0000000ffffffff00090008c0000000ffffffff0001000480"
"00ea60000a00010300030002500000020010616161616262626263636363646"
"46464ff00000100"
)
# parse
buf = Buffer(data=data)
params = pull_quic_transport_parameters(
buf, protocol_version=QuicProtocolVersion.DRAFT_25
)
self.assertEqual(
params,
QuicTransportParameters(
idle_timeout=60000,
stateless_reset_token=b"aaaabbbbccccdddd",
max_packet_size=4096,
initial_max_data=1048576,
initial_max_stream_data_bidi_local=66560,
initial_max_stream_data_bidi_remote=66560,
initial_max_stream_data_uni=66560,
initial_max_streams_bidi=4294967295,
initial_max_streams_uni=4294967295,
ack_delay_exponent=3,
),
)
def test_preferred_address_ipv4_only(self):
data = binascii.unhexlify(
"8ba27b8611530000000000000000000000000000000000001262c4518d63013"
"f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48ecb4"
)
# parse
buf = Buffer(data=data)
preferred_address = pull_quic_preferred_address(buf)
self.assertEqual(
preferred_address,
QuicPreferredAddress(
ipv4_address=("139.162.123.134", 4435),
ipv6_address=None,
connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7",
stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4",
),
)
# serialize
buf = Buffer(capacity=len(data))
push_quic_preferred_address(buf, preferred_address)
self.assertEqual(buf.data, data)
def test_preferred_address_ipv6_only(self):
data = binascii.unhexlify(
"0000000000002400890200000000f03c91fffe69a45411531262c4518d63013"
"f0c287ed3573efa9095603746b2e02d45480ba6643e5c6e7d48ecb4"
)
# parse
buf = Buffer(data=data)
preferred_address = pull_quic_preferred_address(buf)
self.assertEqual(
preferred_address,
QuicPreferredAddress(
ipv4_address=None,
ipv6_address=("2400:8902::f03c:91ff:fe69:a454", 4435),
connection_id=b"b\xc4Q\x8dc\x01?\x0c(~\xd3W>\xfa\x90\x95`7",
stateless_reset_token=b"F\xb2\xe0-EH\x0b\xa6d>\\n}H\xec\xb4",
),
)
# serialize
buf = Buffer(capacity=len(data))
push_quic_preferred_address(buf, preferred_address)
self.assertEqual(buf.data, data)
class FrameTest(TestCase):
def test_ack_frame(self):
data = b"\x00\x02\x00\x00"
# parse
buf = Buffer(data=data)
rangeset, delay = packet.pull_ack_frame(buf)
self.assertEqual(list(rangeset), [range(0, 1)])
self.assertEqual(delay, 2)
# serialize
buf = Buffer(capacity=len(data))
packet.push_ack_frame(buf, rangeset, delay)
self.assertEqual(buf.data, data)
def test_ack_frame_with_one_range(self):
data = b"\x02\x02\x01\x00\x00\x00"
# parse
buf = Buffer(data=data)
rangeset, delay = packet.pull_ack_frame(buf)
self.assertEqual(list(rangeset), [range(0, 1), range(2, 3)])
self.assertEqual(delay, 2)
# serialize
buf = Buffer(capacity=len(data))
packet.push_ack_frame(buf, rangeset, delay)
self.assertEqual(buf.data, data)
def test_ack_frame_with_one_range_2(self):
data = b"\x05\x02\x01\x00\x00\x03"
# parse
buf = Buffer(data=data)
rangeset, delay = packet.pull_ack_frame(buf)
self.assertEqual(list(rangeset), [range(0, 4), range(5, 6)])
self.assertEqual(delay, 2)
# serialize
buf = Buffer(capacity=len(data))
packet.push_ack_frame(buf, rangeset, delay)
self.assertEqual(buf.data, data)
def test_ack_frame_with_one_range_3(self):
data = b"\x05\x02\x01\x00\x01\x02"
# parse
buf = Buffer(data=data)
rangeset, delay = packet.pull_ack_frame(buf)
self.assertEqual(list(rangeset), [range(0, 3), range(5, 6)])
self.assertEqual(delay, 2)
# serialize
buf = Buffer(capacity=len(data))
packet.push_ack_frame(buf, rangeset, delay)
self.assertEqual(buf.data, data)
def test_ack_frame_with_two_ranges(self):
data = b"\x04\x02\x02\x00\x00\x00\x00\x00"
# parse
buf = Buffer(data=data)
rangeset, delay = packet.pull_ack_frame(buf)
self.assertEqual(list(rangeset), [range(0, 1), range(2, 3), range(4, 5)])
self.assertEqual(delay, 2)
# serialize
buf = Buffer(capacity=len(data))
packet.push_ack_frame(buf, rangeset, delay)
self.assertEqual(buf.data, data)
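For readers decoding the FrameTest vectors by hand: an ACK frame body is largest_acked, ack_delay, range_count and first_range, followed by (gap, length) pairs that walk downwards through packet numbers, all as variable-length integers. A minimal sketch of that layout (an illustration of the wire format, not aioquic's pull_ack_frame; it takes already-decoded integers since every varint in these vectors fits in one byte):

from typing import List, Tuple

def decode_ack_body(varints: List[int]) -> Tuple[List[range], int]:
    largest, delay, count, first = varints[0], varints[1], varints[2], varints[3]
    smallest = largest - first
    ranges = [range(smallest, largest + 1)]
    rest = varints[4:]
    for i in range(count):
        gap, length = rest[2 * i], rest[2 * i + 1]
        largest = smallest - gap - 2   # a gap of `gap` means `gap + 1` unacknowledged packets
        smallest = largest - length
        ranges.insert(0, range(smallest, largest + 1))
    return ranges, delay

# b"\x05\x02\x01\x00\x00\x03" decodes to the varints [5, 2, 1, 0, 0, 3]
assert decode_ack_body([5, 2, 1, 0, 0, 3]) == ([range(0, 4), range(5, 6)], 2)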


@ -0,0 +1,569 @@
from unittest import TestCase
from aioquic.quic.crypto import CryptoPair
from aioquic.quic.packet import (
PACKET_TYPE_HANDSHAKE,
PACKET_TYPE_INITIAL,
PACKET_TYPE_ONE_RTT,
QuicFrameType,
QuicProtocolVersion,
)
from aioquic.quic.packet_builder import (
QuicPacketBuilder,
QuicPacketBuilderStop,
QuicSentPacket,
)
from aioquic.tls import Epoch
def create_builder(is_client=False):
return QuicPacketBuilder(
host_cid=bytes(8),
is_client=is_client,
packet_number=0,
peer_cid=bytes(8),
peer_token=b"",
spin_bit=False,
version=QuicProtocolVersion.DRAFT_25,
)
def create_crypto():
crypto = CryptoPair()
crypto.setup_initial(bytes(8), is_client=True, version=QuicProtocolVersion.DRAFT_25)
return crypto
class QuicPacketBuilderTest(TestCase):
def test_long_header_empty(self):
builder = create_builder()
crypto = create_crypto()
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertEqual(builder.remaining_flight_space, 1236)
self.assertTrue(builder.packet_is_empty)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 0)
self.assertEqual(packets, [])
# check builder
self.assertEqual(builder.packet_number, 0)
def test_long_header_padding(self):
builder = create_builder(is_client=True)
crypto = create_crypto()
# INITIAL, fully padded
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertEqual(builder.remaining_flight_space, 1236)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(100))
self.assertFalse(builder.packet_is_empty)
# INITIAL, empty
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertTrue(builder.packet_is_empty)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 1)
self.assertEqual(len(datagrams[0]), 1280)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.INITIAL,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_INITIAL,
sent_bytes=1280,
)
],
)
# check builder
self.assertEqual(builder.packet_number, 1)
def test_long_header_initial_client_2(self):
builder = create_builder(is_client=True)
crypto = create_crypto()
# INITIAL, full length
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertEqual(builder.remaining_flight_space, 1236)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(builder.remaining_flight_space))
self.assertFalse(builder.packet_is_empty)
# INITIAL
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertEqual(builder.remaining_flight_space, 1236)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(100))
self.assertFalse(builder.packet_is_empty)
# INITIAL, empty
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertTrue(builder.packet_is_empty)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 2)
self.assertEqual(len(datagrams[0]), 1280)
self.assertEqual(len(datagrams[1]), 1280)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.INITIAL,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_INITIAL,
sent_bytes=1280,
),
QuicSentPacket(
epoch=Epoch.INITIAL,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=1,
packet_type=PACKET_TYPE_INITIAL,
sent_bytes=1280,
),
],
)
# check builder
self.assertEqual(builder.packet_number, 2)
def test_long_header_initial_server(self):
builder = create_builder()
crypto = create_crypto()
# INITIAL
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertEqual(builder.remaining_flight_space, 1236)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(100))
self.assertFalse(builder.packet_is_empty)
# INITIAL, empty
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertTrue(builder.packet_is_empty)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 1)
self.assertEqual(len(datagrams[0]), 145)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.INITIAL,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_INITIAL,
sent_bytes=145,
)
],
)
# check builder
self.assertEqual(builder.packet_number, 1)
def test_long_header_then_short_header(self):
builder = create_builder()
crypto = create_crypto()
# INITIAL, full length
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertEqual(builder.remaining_flight_space, 1236)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(builder.remaining_flight_space))
self.assertFalse(builder.packet_is_empty)
# INITIAL, empty
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertTrue(builder.packet_is_empty)
# ONE_RTT, full length
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 1253)
buf = builder.start_frame(QuicFrameType.STREAM_BASE)
buf.push_bytes(bytes(builder.remaining_flight_space))
self.assertFalse(builder.packet_is_empty)
# ONE_RTT, empty
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
self.assertTrue(builder.packet_is_empty)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 2)
self.assertEqual(len(datagrams[0]), 1280)
self.assertEqual(len(datagrams[1]), 1280)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.INITIAL,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_INITIAL,
sent_bytes=1280,
),
QuicSentPacket(
epoch=Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=False,
packet_number=1,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=1280,
),
],
)
# check builder
self.assertEqual(builder.packet_number, 2)
def test_long_header_then_long_header(self):
builder = create_builder()
crypto = create_crypto()
# INITIAL
builder.start_packet(PACKET_TYPE_INITIAL, crypto)
self.assertEqual(builder.remaining_flight_space, 1236)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(199))
self.assertFalse(builder.packet_is_empty)
# HANDSHAKE
builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto)
self.assertEqual(builder.remaining_flight_space, 993)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(299))
self.assertFalse(builder.packet_is_empty)
# ONE_RTT
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 666)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(299))
self.assertFalse(builder.packet_is_empty)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 1)
self.assertEqual(len(datagrams[0]), 914)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.INITIAL,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_INITIAL,
sent_bytes=244,
),
QuicSentPacket(
epoch=Epoch.HANDSHAKE,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=1,
packet_type=PACKET_TYPE_HANDSHAKE,
sent_bytes=343,
),
QuicSentPacket(
epoch=Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=2,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=327,
),
],
)
# check builder
self.assertEqual(builder.packet_number, 3)
def test_short_header_empty(self):
builder = create_builder()
crypto = create_crypto()
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 1253)
self.assertTrue(builder.packet_is_empty)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(datagrams, [])
self.assertEqual(packets, [])
# check builder
self.assertEqual(builder.packet_number, 0)
def test_short_header_padding(self):
builder = create_builder()
crypto = create_crypto()
# ONE_RTT, full length
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 1253)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(builder.remaining_flight_space))
self.assertFalse(builder.packet_is_empty)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 1)
self.assertEqual(len(datagrams[0]), 1280)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=1280,
)
],
)
# check builder
self.assertEqual(builder.packet_number, 1)
def test_short_header_max_flight_bytes(self):
"""
max_flight_bytes limits sent data.
"""
builder = create_builder()
builder.max_flight_bytes = 1000
crypto = create_crypto()
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 973)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(builder.remaining_flight_space))
self.assertFalse(builder.packet_is_empty)
with self.assertRaises(QuicPacketBuilderStop):
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
builder.start_frame(QuicFrameType.CRYPTO)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 1)
self.assertEqual(len(datagrams[0]), 1000)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=1000,
),
],
)
# check builder
self.assertEqual(builder.packet_number, 1)
def test_short_header_max_flight_bytes_zero(self):
"""
max_flight_bytes = 0 only allows ACKs and CONNECTION_CLOSE.
Check CRYPTO is not allowed.
"""
builder = create_builder()
builder.max_flight_bytes = 0
crypto = create_crypto()
with self.assertRaises(QuicPacketBuilderStop):
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
builder.start_frame(QuicFrameType.CRYPTO)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 0)
# check builder
self.assertEqual(builder.packet_number, 0)
def test_short_header_max_flight_bytes_zero_ack(self):
"""
max_flight_bytes = 0 only allows ACKs and CONNECTION_CLOSE.
Check ACK is allowed.
"""
builder = create_builder()
builder.max_flight_bytes = 0
crypto = create_crypto()
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
buf = builder.start_frame(QuicFrameType.ACK)
buf.push_bytes(bytes(64))
with self.assertRaises(QuicPacketBuilderStop):
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
builder.start_frame(QuicFrameType.CRYPTO)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 1)
self.assertEqual(len(datagrams[0]), 92)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.ONE_RTT,
in_flight=False,
is_ack_eliciting=False,
is_crypto_packet=False,
packet_number=0,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=92,
),
],
)
# check builder
self.assertEqual(builder.packet_number, 1)
def test_short_header_max_total_bytes_1(self):
"""
max_total_bytes doesn't allow any packets.
"""
builder = create_builder()
builder.max_total_bytes = 11
crypto = create_crypto()
with self.assertRaises(QuicPacketBuilderStop):
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(datagrams, [])
self.assertEqual(packets, [])
# check builder
self.assertEqual(builder.packet_number, 0)
def test_short_header_max_total_bytes_2(self):
"""
max_total_bytes allows a short packet.
"""
builder = create_builder()
builder.max_total_bytes = 800
crypto = create_crypto()
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 773)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(builder.remaining_flight_space))
self.assertFalse(builder.packet_is_empty)
with self.assertRaises(QuicPacketBuilderStop):
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 1)
self.assertEqual(len(datagrams[0]), 800)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=800,
)
],
)
# check builder
self.assertEqual(builder.packet_number, 1)
def test_short_header_max_total_bytes_3(self):
builder = create_builder()
builder.max_total_bytes = 2000
crypto = create_crypto()
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 1253)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(builder.remaining_flight_space))
self.assertFalse(builder.packet_is_empty)
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
self.assertEqual(builder.remaining_flight_space, 693)
buf = builder.start_frame(QuicFrameType.CRYPTO)
buf.push_bytes(bytes(builder.remaining_flight_space))
self.assertFalse(builder.packet_is_empty)
with self.assertRaises(QuicPacketBuilderStop):
builder.start_packet(PACKET_TYPE_ONE_RTT, crypto)
# check datagrams
datagrams, packets = builder.flush()
self.assertEqual(len(datagrams), 2)
self.assertEqual(len(datagrams[0]), 1280)
self.assertEqual(len(datagrams[1]), 720)
self.assertEqual(
packets,
[
QuicSentPacket(
epoch=Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=1280,
),
QuicSentPacket(
epoch=Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=1,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=720,
),
],
)
# check builder
self.assertEqual(builder.packet_number, 2)
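The 1236 and 1253 bytes of remaining_flight_space asserted throughout these tests fall out of simple arithmetic on a 1280-byte datagram. A back-of-the-envelope sketch under the assumptions the fixture appears to use (8-byte connection IDs, 2-byte packet numbers, a 16-byte AEAD tag, an empty token and a 2-byte length field):

MAX_DATAGRAM = 1280
AEAD_TAG = 16
PN = 2    # packet-number bytes reserved by the builder (assumption)
CID = 8   # host/peer connection IDs in create_builder()

# long (INITIAL) header: flags, version, dcid len + dcid, scid len + scid,
# token length, length field, packet number
long_header = 1 + 4 + (1 + CID) + (1 + CID) + 1 + 2 + PN
# short (ONE_RTT) header: flags, dcid, packet number
short_header = 1 + CID + PN

assert MAX_DATAGRAM - long_header - AEAD_TAG == 1236
assert MAX_DATAGRAM - short_header - AEAD_TAG == 1253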


@ -0,0 +1,237 @@
from unittest import TestCase
from aioquic.quic.rangeset import RangeSet
class RangeSetTest(TestCase):
def test_add_single_duplicate(self):
rangeset = RangeSet()
rangeset.add(0)
self.assertEqual(list(rangeset), [range(0, 1)])
rangeset.add(0)
self.assertEqual(list(rangeset), [range(0, 1)])
def test_add_single_ordered(self):
rangeset = RangeSet()
rangeset.add(0)
self.assertEqual(list(rangeset), [range(0, 1)])
rangeset.add(1)
self.assertEqual(list(rangeset), [range(0, 2)])
rangeset.add(2)
self.assertEqual(list(rangeset), [range(0, 3)])
def test_add_single_merge(self):
rangeset = RangeSet()
rangeset.add(0)
self.assertEqual(list(rangeset), [range(0, 1)])
rangeset.add(2)
self.assertEqual(list(rangeset), [range(0, 1), range(2, 3)])
rangeset.add(1)
self.assertEqual(list(rangeset), [range(0, 3)])
def test_add_single_reverse(self):
rangeset = RangeSet()
rangeset.add(2)
self.assertEqual(list(rangeset), [range(2, 3)])
rangeset.add(1)
self.assertEqual(list(rangeset), [range(1, 3)])
rangeset.add(0)
self.assertEqual(list(rangeset), [range(0, 3)])
def test_add_range_ordered(self):
rangeset = RangeSet()
rangeset.add(0, 2)
self.assertEqual(list(rangeset), [range(0, 2)])
rangeset.add(2, 4)
self.assertEqual(list(rangeset), [range(0, 4)])
rangeset.add(4, 6)
self.assertEqual(list(rangeset), [range(0, 6)])
def test_add_range_merge(self):
rangeset = RangeSet()
rangeset.add(0, 2)
self.assertEqual(list(rangeset), [range(0, 2)])
rangeset.add(3, 5)
self.assertEqual(list(rangeset), [range(0, 2), range(3, 5)])
rangeset.add(2, 3)
self.assertEqual(list(rangeset), [range(0, 5)])
def test_add_range_overlap(self):
rangeset = RangeSet()
rangeset.add(0, 2)
self.assertEqual(list(rangeset), [range(0, 2)])
rangeset.add(3, 5)
self.assertEqual(list(rangeset), [range(0, 2), range(3, 5)])
rangeset.add(1, 5)
self.assertEqual(list(rangeset), [range(0, 5)])
def test_add_range_overlap_2(self):
rangeset = RangeSet()
rangeset.add(2, 4)
rangeset.add(6, 8)
rangeset.add(10, 12)
rangeset.add(16, 18)
self.assertEqual(
list(rangeset), [range(2, 4), range(6, 8), range(10, 12), range(16, 18)]
)
rangeset.add(1, 15)
self.assertEqual(list(rangeset), [range(1, 15), range(16, 18)])
def test_add_range_reverse(self):
rangeset = RangeSet()
rangeset.add(6, 8)
self.assertEqual(list(rangeset), [range(6, 8)])
rangeset.add(3, 5)
self.assertEqual(list(rangeset), [range(3, 5), range(6, 8)])
rangeset.add(0, 2)
self.assertEqual(list(rangeset), [range(0, 2), range(3, 5), range(6, 8)])
def test_add_range_unordered_contiguous(self):
rangeset = RangeSet()
rangeset.add(0, 2)
self.assertEqual(list(rangeset), [range(0, 2)])
rangeset.add(4, 6)
self.assertEqual(list(rangeset), [range(0, 2), range(4, 6)])
rangeset.add(2, 4)
self.assertEqual(list(rangeset), [range(0, 6)])
def test_add_range_unordered_sparse(self):
rangeset = RangeSet()
rangeset.add(0, 2)
self.assertEqual(list(rangeset), [range(0, 2)])
rangeset.add(6, 8)
self.assertEqual(list(rangeset), [range(0, 2), range(6, 8)])
rangeset.add(3, 5)
self.assertEqual(list(rangeset), [range(0, 2), range(3, 5), range(6, 8)])
def test_subtract(self):
rangeset = RangeSet()
rangeset.add(0, 10)
rangeset.add(20, 30)
rangeset.subtract(0, 3)
self.assertEqual(list(rangeset), [range(3, 10), range(20, 30)])
def test_subtract_no_change(self):
rangeset = RangeSet()
rangeset.add(5, 10)
rangeset.add(15, 20)
rangeset.add(25, 30)
rangeset.subtract(0, 5)
self.assertEqual(list(rangeset), [range(5, 10), range(15, 20), range(25, 30)])
rangeset.subtract(10, 15)
self.assertEqual(list(rangeset), [range(5, 10), range(15, 20), range(25, 30)])
def test_subtract_overlap(self):
rangeset = RangeSet()
rangeset.add(1, 4)
rangeset.add(6, 8)
rangeset.add(10, 20)
rangeset.add(30, 40)
self.assertEqual(
list(rangeset), [range(1, 4), range(6, 8), range(10, 20), range(30, 40)]
)
rangeset.subtract(0, 2)
self.assertEqual(
list(rangeset), [range(2, 4), range(6, 8), range(10, 20), range(30, 40)]
)
rangeset.subtract(3, 11)
self.assertEqual(list(rangeset), [range(2, 3), range(11, 20), range(30, 40)])
def test_subtract_split(self):
rangeset = RangeSet()
rangeset.add(0, 10)
rangeset.subtract(2, 5)
self.assertEqual(list(rangeset), [range(0, 2), range(5, 10)])
def test_bool(self):
with self.assertRaises(NotImplementedError):
bool(RangeSet())
def test_contains(self):
rangeset = RangeSet()
self.assertFalse(0 in rangeset)
rangeset = RangeSet([range(0, 1)])
self.assertTrue(0 in rangeset)
self.assertFalse(1 in rangeset)
rangeset = RangeSet([range(0, 1), range(3, 6)])
self.assertTrue(0 in rangeset)
self.assertFalse(1 in rangeset)
self.assertFalse(2 in rangeset)
self.assertTrue(3 in rangeset)
self.assertTrue(4 in rangeset)
self.assertTrue(5 in rangeset)
self.assertFalse(6 in rangeset)
def test_eq(self):
r0 = RangeSet([range(0, 1)])
r1 = RangeSet([range(1, 2), range(3, 4)])
r2 = RangeSet([range(3, 4), range(1, 2)])
self.assertTrue(r0 == r0)
self.assertFalse(r0 == r1)
self.assertFalse(r0 == 0)
self.assertTrue(r1 == r1)
self.assertFalse(r1 == r0)
self.assertTrue(r1 == r2)
self.assertFalse(r1 == 0)
self.assertTrue(r2 == r2)
self.assertTrue(r2 == r1)
self.assertFalse(r2 == r0)
self.assertFalse(r2 == 0)
def test_len(self):
rangeset = RangeSet()
self.assertEqual(len(rangeset), 0)
rangeset = RangeSet([range(0, 1)])
self.assertEqual(len(rangeset), 1)
def test_pop(self):
rangeset = RangeSet([range(1, 2), range(3, 4)])
r = rangeset.shift()
self.assertEqual(r, range(1, 2))
self.assertEqual(list(rangeset), [range(3, 4)])
def test_repr(self):
rangeset = RangeSet([range(1, 2), range(3, 4)])
self.assertEqual(repr(rangeset), "RangeSet([range(1, 2), range(3, 4)])")
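# A minimal usage sketch of RangeSet, pieced together from the operations the
# tests above exercise (add/subtract/shift, membership, iteration); anything
# beyond those calls is an assumption rather than documented API.
from aioquic.quic.rangeset import RangeSet

acked = RangeSet()
acked.add(0, 10)                 # [range(0, 10)]
acked.add(10, 12)                # contiguous spans merge: [range(0, 12)]
acked.subtract(4, 6)             # splitting: [range(0, 4), range(6, 12)]
assert 3 in acked and 5 not in acked
lowest = acked.shift()           # removes and returns the lowest range
assert lowest == range(0, 4)
assert list(acked) == [range(6, 12)]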

View file

@ -0,0 +1,227 @@
import math
from unittest import TestCase
from aioquic import tls
from aioquic.quic.packet import PACKET_TYPE_INITIAL, PACKET_TYPE_ONE_RTT
from aioquic.quic.packet_builder import QuicSentPacket
from aioquic.quic.rangeset import RangeSet
from aioquic.quic.recovery import (
QuicPacketPacer,
QuicPacketRecovery,
QuicPacketSpace,
QuicRttMonitor,
)
def send_probe():
pass
class QuicPacketPacerTest(TestCase):
def setUp(self):
self.pacer = QuicPacketPacer()
def test_no_measurement(self):
self.assertIsNone(self.pacer.next_send_time(now=0.0))
self.pacer.update_after_send(now=0.0)
self.assertIsNone(self.pacer.next_send_time(now=0.0))
self.pacer.update_after_send(now=0.0)
def test_with_measurement(self):
self.assertIsNone(self.pacer.next_send_time(now=0.0))
self.pacer.update_after_send(now=0.0)
self.pacer.update_rate(congestion_window=1280000, smoothed_rtt=0.05)
self.assertEqual(self.pacer.bucket_max, 0.0008)
self.assertEqual(self.pacer.bucket_time, 0.0)
self.assertEqual(self.pacer.packet_time, 0.00005)
# 16 packets
for i in range(16):
self.assertIsNone(self.pacer.next_send_time(now=1.0))
self.pacer.update_after_send(now=1.0)
self.assertAlmostEqual(self.pacer.next_send_time(now=1.0), 1.00005)
# 2 packets
for i in range(2):
self.assertIsNone(self.pacer.next_send_time(now=1.00005))
self.pacer.update_after_send(now=1.00005)
self.assertAlmostEqual(self.pacer.next_send_time(now=1.00005), 1.0001)
# 1 packet
self.assertIsNone(self.pacer.next_send_time(now=1.0001))
self.pacer.update_after_send(now=1.0001)
self.assertAlmostEqual(self.pacer.next_send_time(now=1.0001), 1.00015)
# 2 packets
for i in range(2):
self.assertIsNone(self.pacer.next_send_time(now=1.00015))
self.pacer.update_after_send(now=1.00015)
self.assertAlmostEqual(self.pacer.next_send_time(now=1.00015), 1.0002)
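# The constants asserted in test_with_measurement are consistent with a
# token-bucket pacer and a 1280-byte packet size (an assumption, matching the
# sent_bytes used elsewhere in these tests):
#   pacing rate = congestion_window / smoothed_rtt = 1280000 / 0.05 = 25.6 MB/s
#   packet_time = 1280 / 25.6e6                                     = 50 us
#   bucket_max  = 16 * packet_time                                  = 0.8 ms
# which is why 16 back-to-back sends are accepted before next_send_time
# starts advancing in 50 us steps.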
class QuicPacketRecoveryTest(TestCase):
def setUp(self):
self.INITIAL_SPACE = QuicPacketSpace()
self.HANDSHAKE_SPACE = QuicPacketSpace()
self.ONE_RTT_SPACE = QuicPacketSpace()
self.recovery = QuicPacketRecovery(
is_client_without_1rtt=False, send_probe=send_probe
)
self.recovery.spaces = [
self.INITIAL_SPACE,
self.HANDSHAKE_SPACE,
self.ONE_RTT_SPACE,
]
def test_discard_space(self):
self.recovery.discard_space(self.INITIAL_SPACE)
def test_on_ack_received_ack_eliciting(self):
packet = QuicSentPacket(
epoch=tls.Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=False,
packet_number=0,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=1280,
sent_time=0.0,
)
space = self.ONE_RTT_SPACE
# packet sent
self.recovery.on_packet_sent(packet, space)
self.assertEqual(self.recovery.bytes_in_flight, 1280)
self.assertEqual(space.ack_eliciting_in_flight, 1)
self.assertEqual(len(space.sent_packets), 1)
# packet ack'd
self.recovery.on_ack_received(
space, ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0
)
self.assertEqual(self.recovery.bytes_in_flight, 0)
self.assertEqual(space.ack_eliciting_in_flight, 0)
self.assertEqual(len(space.sent_packets), 0)
# check RTT
self.assertTrue(self.recovery._rtt_initialized)
self.assertEqual(self.recovery._rtt_latest, 10.0)
self.assertEqual(self.recovery._rtt_min, 10.0)
self.assertEqual(self.recovery._rtt_smoothed, 10.0)
def test_on_ack_received_non_ack_eliciting(self):
packet = QuicSentPacket(
epoch=tls.Epoch.ONE_RTT,
in_flight=True,
is_ack_eliciting=False,
is_crypto_packet=False,
packet_number=0,
packet_type=PACKET_TYPE_ONE_RTT,
sent_bytes=1280,
sent_time=123.45,
)
space = self.ONE_RTT_SPACE
# packet sent
self.recovery.on_packet_sent(packet, space)
self.assertEqual(self.recovery.bytes_in_flight, 1280)
self.assertEqual(space.ack_eliciting_in_flight, 0)
self.assertEqual(len(space.sent_packets), 1)
# packet ack'd
self.recovery.on_ack_received(
space, ack_rangeset=RangeSet([range(0, 1)]), ack_delay=0.0, now=10.0
)
self.assertEqual(self.recovery.bytes_in_flight, 0)
self.assertEqual(space.ack_eliciting_in_flight, 0)
self.assertEqual(len(space.sent_packets), 0)
# check RTT
self.assertFalse(self.recovery._rtt_initialized)
self.assertEqual(self.recovery._rtt_latest, 0.0)
self.assertEqual(self.recovery._rtt_min, math.inf)
self.assertEqual(self.recovery._rtt_smoothed, 0.0)
def test_on_packet_lost_crypto(self):
packet = QuicSentPacket(
epoch=tls.Epoch.INITIAL,
in_flight=True,
is_ack_eliciting=True,
is_crypto_packet=True,
packet_number=0,
packet_type=PACKET_TYPE_INITIAL,
sent_bytes=1280,
sent_time=0.0,
)
space = self.INITIAL_SPACE
self.recovery.on_packet_sent(packet, space)
self.assertEqual(self.recovery.bytes_in_flight, 1280)
self.assertEqual(space.ack_eliciting_in_flight, 1)
self.assertEqual(len(space.sent_packets), 1)
self.recovery._detect_loss(space, now=1.0)
self.assertEqual(self.recovery.bytes_in_flight, 0)
self.assertEqual(space.ack_eliciting_in_flight, 0)
self.assertEqual(len(space.sent_packets), 0)
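# Taken together, the two on_ack_received tests show that only acks covering
# an ack-eliciting packet feed the RTT estimator: the first such sample
# (now - sent_time = 10.0 here, with ack_delay 0) seeds _rtt_latest, _rtt_min
# and _rtt_smoothed in one go, while acks for non-ack-eliciting packets leave
# _rtt_initialized False and _rtt_min at math.inf. Bytes in flight are
# released in both cases.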
class QuicRttMonitorTest(TestCase):
def test_monitor(self):
monitor = QuicRttMonitor()
self.assertFalse(monitor.is_rtt_increasing(rtt=10, now=1000))
self.assertEqual(monitor._samples, [10, 0.0, 0.0, 0.0, 0.0])
self.assertFalse(monitor._ready)
# not taken into account
self.assertFalse(monitor.is_rtt_increasing(rtt=11, now=1000))
self.assertEqual(monitor._samples, [10, 0.0, 0.0, 0.0, 0.0])
self.assertFalse(monitor._ready)
self.assertFalse(monitor.is_rtt_increasing(rtt=11, now=1001))
self.assertEqual(monitor._samples, [10, 11, 0.0, 0.0, 0.0])
self.assertFalse(monitor._ready)
self.assertFalse(monitor.is_rtt_increasing(rtt=12, now=1002))
self.assertEqual(monitor._samples, [10, 11, 12, 0.0, 0.0])
self.assertFalse(monitor._ready)
self.assertFalse(monitor.is_rtt_increasing(rtt=13, now=1003))
self.assertEqual(monitor._samples, [10, 11, 12, 13, 0.0])
self.assertFalse(monitor._ready)
# we now have enough samples
self.assertFalse(monitor.is_rtt_increasing(rtt=14, now=1004))
self.assertEqual(monitor._samples, [10, 11, 12, 13, 14])
self.assertTrue(monitor._ready)
self.assertFalse(monitor.is_rtt_increasing(rtt=20, now=1005))
self.assertEqual(monitor._increases, 0)
self.assertFalse(monitor.is_rtt_increasing(rtt=30, now=1006))
self.assertEqual(monitor._increases, 0)
self.assertFalse(monitor.is_rtt_increasing(rtt=40, now=1007))
self.assertEqual(monitor._increases, 0)
self.assertFalse(monitor.is_rtt_increasing(rtt=50, now=1008))
self.assertEqual(monitor._increases, 0)
self.assertFalse(monitor.is_rtt_increasing(rtt=60, now=1009))
self.assertEqual(monitor._increases, 1)
self.assertFalse(monitor.is_rtt_increasing(rtt=70, now=1010))
self.assertEqual(monitor._increases, 2)
self.assertFalse(monitor.is_rtt_increasing(rtt=80, now=1011))
self.assertEqual(monitor._increases, 3)
self.assertFalse(monitor.is_rtt_increasing(rtt=90, now=1012))
self.assertEqual(monitor._increases, 4)
self.assertTrue(monitor.is_rtt_increasing(rtt=100, now=1013))
self.assertEqual(monitor._increases, 5)
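# Reading the assertions above: QuicRttMonitor records one sample per distinct
# `now` value into a five-slot window and only starts judging once the window
# is full (_ready). Sustained RTT growth is then tallied in _increases, and
# is_rtt_increasing flips to True once that counter reaches 5 (rtt=100 here).
# This summary is inferred from the expected values, not from the monitor's
# implementation.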

View file

@ -0,0 +1,30 @@
from unittest import TestCase
from aioquic.quic.retry import QuicRetryTokenHandler
class QuicRetryTokenHandlerTest(TestCase):
def test_retry_token(self):
addr = ("127.0.0.1", 1234)
cid = b"\x08\x07\x06\05\x04\x03\x02\x01"
handler = QuicRetryTokenHandler()
# create token
token = handler.create_token(addr, cid)
self.assertIsNotNone(token)
# validate token - ok
self.assertEqual(handler.validate_token(addr, token), cid)
# validate token - empty
with self.assertRaises(ValueError) as cm:
handler.validate_token(addr, b"")
self.assertEqual(
str(cm.exception), "Ciphertext length must be equal to key size."
)
# validate token - wrong address
with self.assertRaises(ValueError) as cm:
handler.validate_token(("1.2.3.4", 12345), token)
self.assertEqual(str(cm.exception), "Remote address does not match.")

View file

@ -0,0 +1,446 @@
from unittest import TestCase
from aioquic.quic.events import StreamDataReceived
from aioquic.quic.packet import QuicStreamFrame
from aioquic.quic.packet_builder import QuicDeliveryState
from aioquic.quic.stream import QuicStream
class QuicStreamTest(TestCase):
def test_recv_empty(self):
stream = QuicStream(stream_id=0)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 0)
# empty
self.assertEqual(stream.add_frame(QuicStreamFrame(offset=0, data=b"")), None)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 0)
def test_recv_ordered(self):
stream = QuicStream(stream_id=0)
# add data at start
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")),
StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0),
)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 8)
# add more data
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345")),
StreamDataReceived(data=b"89012345", end_stream=False, stream_id=0),
)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 16)
# add data and fin
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=16, data=b"67890123", fin=True)),
StreamDataReceived(data=b"67890123", end_stream=True, stream_id=0),
)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 24)
def test_recv_unordered(self):
stream = QuicStream(stream_id=0)
# add data at offset 8
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345")), None
)
self.assertEqual(
bytes(stream._recv_buffer), b"\x00\x00\x00\x00\x00\x00\x00\x0089012345"
)
self.assertEqual(list(stream._recv_ranges), [range(8, 16)])
self.assertEqual(stream._recv_buffer_start, 0)
# add data at offset 0
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")),
StreamDataReceived(data=b"0123456789012345", end_stream=False, stream_id=0),
)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 16)
def test_recv_offset_only(self):
stream = QuicStream(stream_id=0)
# add data at offset 0
self.assertEqual(stream.add_frame(QuicStreamFrame(offset=0, data=b"")), None)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 0)
# add data at offset 8
self.assertEqual(stream.add_frame(QuicStreamFrame(offset=8, data=b"")), None)
self.assertEqual(
bytes(stream._recv_buffer), b"\x00\x00\x00\x00\x00\x00\x00\x00"
)
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 0)
def test_recv_already_fully_consumed(self):
stream = QuicStream(stream_id=0)
# add data at offset 0
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")),
StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0),
)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 8)
# add data again at offset 0
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")), None
)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 8)
def test_recv_already_partially_consumed(self):
stream = QuicStream(stream_id=0)
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")),
StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0),
)
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"0123456789012345")),
StreamDataReceived(data=b"89012345", end_stream=False, stream_id=0),
)
self.assertEqual(bytes(stream._recv_buffer), b"")
self.assertEqual(list(stream._recv_ranges), [])
self.assertEqual(stream._recv_buffer_start, 16)
def test_recv_fin(self):
stream = QuicStream(stream_id=0)
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")),
StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0),
)
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)),
StreamDataReceived(data=b"89012345", end_stream=True, stream_id=0),
)
def test_recv_fin_out_of_order(self):
stream = QuicStream(stream_id=0)
# add data at offset 8 with FIN
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)),
None,
)
# add data at offset 0
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")),
StreamDataReceived(data=b"0123456789012345", end_stream=True, stream_id=0),
)
def test_recv_fin_then_data(self):
stream = QuicStream(stream_id=0)
stream.add_frame(QuicStreamFrame(offset=0, data=b"", fin=True))
with self.assertRaises(Exception) as cm:
stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567"))
self.assertEqual(str(cm.exception), "Data received beyond FIN")
def test_recv_fin_twice(self):
stream = QuicStream(stream_id=0)
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"01234567")),
StreamDataReceived(data=b"01234567", end_stream=False, stream_id=0),
)
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)),
StreamDataReceived(data=b"89012345", end_stream=True, stream_id=0),
)
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=8, data=b"89012345", fin=True)),
StreamDataReceived(data=b"", end_stream=True, stream_id=0),
)
def test_recv_fin_without_data(self):
stream = QuicStream(stream_id=0)
self.assertEqual(
stream.add_frame(QuicStreamFrame(offset=0, data=b"", fin=True)),
StreamDataReceived(data=b"", end_stream=True, stream_id=0),
)
def test_send_data(self):
stream = QuicStream()
self.assertEqual(stream.next_send_offset, 0)
# nothing to send yet
frame = stream.get_frame(8)
self.assertIsNone(frame)
# write data
stream.write(b"0123456789012345")
self.assertEqual(list(stream._send_pending), [range(0, 16)])
self.assertEqual(stream.next_send_offset, 0)
# send a chunk
frame = stream.get_frame(8)
self.assertEqual(frame.data, b"01234567")
self.assertFalse(frame.fin)
self.assertEqual(frame.offset, 0)
self.assertEqual(list(stream._send_pending), [range(8, 16)])
self.assertEqual(stream.next_send_offset, 8)
# send another chunk
frame = stream.get_frame(8)
self.assertEqual(frame.data, b"89012345")
self.assertFalse(frame.fin)
self.assertEqual(frame.offset, 8)
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 16)
# nothing more to send
frame = stream.get_frame(8)
self.assertIsNone(frame)
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 16)
# first chunk gets acknowledged
stream.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
# second chunk gets acknowledged
stream.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
def test_send_data_and_fin(self):
stream = QuicStream()
# nothing to send yet
frame = stream.get_frame(8)
self.assertIsNone(frame)
# write data and EOF
stream.write(b"0123456789012345", end_stream=True)
self.assertEqual(list(stream._send_pending), [range(0, 16)])
self.assertEqual(stream.next_send_offset, 0)
# send a chunk
frame = stream.get_frame(8)
self.assertEqual(frame.data, b"01234567")
self.assertFalse(frame.fin)
self.assertEqual(frame.offset, 0)
self.assertEqual(stream.next_send_offset, 8)
# send another chunk
frame = stream.get_frame(8)
self.assertEqual(frame.data, b"89012345")
self.assertTrue(frame.fin)
self.assertEqual(frame.offset, 8)
self.assertEqual(stream.next_send_offset, 16)
# nothing more to send
frame = stream.get_frame(8)
self.assertIsNone(frame)
self.assertEqual(stream.next_send_offset, 16)
def test_send_data_lost(self):
stream = QuicStream()
# nothing to send yet
frame = stream.get_frame(8)
self.assertIsNone(frame)
# write data and EOF
stream.write(b"0123456789012345", end_stream=True)
self.assertEqual(list(stream._send_pending), [range(0, 16)])
self.assertEqual(stream.next_send_offset, 0)
# send a chunk
self.assertEqual(
stream.get_frame(8), QuicStreamFrame(data=b"01234567", fin=False, offset=0)
)
self.assertEqual(list(stream._send_pending), [range(8, 16)])
self.assertEqual(stream.next_send_offset, 8)
# send another chunk
self.assertEqual(
stream.get_frame(8), QuicStreamFrame(data=b"89012345", fin=True, offset=8)
)
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 16)
# nothing more to send
self.assertIsNone(stream.get_frame(8))
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 16)
# a chunk gets lost
stream.on_data_delivery(QuicDeliveryState.LOST, 0, 8)
self.assertEqual(list(stream._send_pending), [range(0, 8)])
self.assertEqual(stream.next_send_offset, 0)
# send chunk again
self.assertEqual(
stream.get_frame(8), QuicStreamFrame(data=b"01234567", fin=False, offset=0)
)
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 16)
def test_send_data_lost_fin(self):
stream = QuicStream()
# nothing to send yet
frame = stream.get_frame(8)
self.assertIsNone(frame)
# write data and EOF
stream.write(b"0123456789012345", end_stream=True)
self.assertEqual(list(stream._send_pending), [range(0, 16)])
self.assertEqual(stream.next_send_offset, 0)
# send a chunk
self.assertEqual(
stream.get_frame(8), QuicStreamFrame(data=b"01234567", fin=False, offset=0)
)
self.assertEqual(list(stream._send_pending), [range(8, 16)])
self.assertEqual(stream.next_send_offset, 8)
# send another chunk
self.assertEqual(
stream.get_frame(8), QuicStreamFrame(data=b"89012345", fin=True, offset=8)
)
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 16)
# nothing more to send
self.assertIsNone(stream.get_frame(8))
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 16)
# a chunk gets lost
stream.on_data_delivery(QuicDeliveryState.LOST, 8, 16)
self.assertEqual(list(stream._send_pending), [range(8, 16)])
self.assertEqual(stream.next_send_offset, 8)
# send chunk again
self.assertEqual(
stream.get_frame(8), QuicStreamFrame(data=b"89012345", fin=True, offset=8)
)
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 16)
def test_send_blocked(self):
stream = QuicStream()
max_offset = 12
# nothing to send yet
frame = stream.get_frame(8, max_offset)
self.assertIsNone(frame)
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 0)
# write data, send a chunk
stream.write(b"0123456789012345")
frame = stream.get_frame(8)
self.assertEqual(frame.data, b"01234567")
self.assertFalse(frame.fin)
self.assertEqual(frame.offset, 0)
self.assertEqual(list(stream._send_pending), [range(8, 16)])
self.assertEqual(stream.next_send_offset, 8)
# send is limited by peer
frame = stream.get_frame(8, max_offset)
self.assertEqual(frame.data, b"8901")
self.assertFalse(frame.fin)
self.assertEqual(frame.offset, 8)
self.assertEqual(list(stream._send_pending), [range(12, 16)])
self.assertEqual(stream.next_send_offset, 12)
# unable to send, blocked
frame = stream.get_frame(8, max_offset)
self.assertIsNone(frame)
self.assertEqual(list(stream._send_pending), [range(12, 16)])
self.assertEqual(stream.next_send_offset, 12)
# write more data, still blocked
stream.write(b"abcdefgh")
frame = stream.get_frame(8, max_offset)
self.assertIsNone(frame)
self.assertEqual(list(stream._send_pending), [range(12, 24)])
self.assertEqual(stream.next_send_offset, 12)
# peer raises limit, send some data
max_offset += 8
frame = stream.get_frame(8, max_offset)
self.assertEqual(frame.data, b"2345abcd")
self.assertFalse(frame.fin)
self.assertEqual(frame.offset, 12)
self.assertEqual(list(stream._send_pending), [range(20, 24)])
self.assertEqual(stream.next_send_offset, 20)
# peer raises limit again, send remaining data
max_offset += 8
frame = stream.get_frame(8, max_offset)
self.assertEqual(frame.data, b"efgh")
self.assertFalse(frame.fin)
self.assertEqual(frame.offset, 20)
self.assertEqual(list(stream._send_pending), [])
self.assertEqual(stream.next_send_offset, 24)
# nothing more to send
frame = stream.get_frame(8, max_offset)
self.assertIsNone(frame)
def test_send_fin_only(self):
stream = QuicStream()
# nothing to send yet
self.assertTrue(stream.send_buffer_is_empty)
frame = stream.get_frame(8)
self.assertIsNone(frame)
# write EOF
stream.write(b"", end_stream=True)
self.assertFalse(stream.send_buffer_is_empty)
frame = stream.get_frame(8)
self.assertEqual(frame.data, b"")
self.assertTrue(frame.fin)
self.assertEqual(frame.offset, 0)
# nothing more to send
self.assertFalse(stream.send_buffer_is_empty) # FIXME?
frame = stream.get_frame(8)
self.assertIsNone(frame)
self.assertTrue(stream.send_buffer_is_empty)
def test_send_fin_only_despite_blocked(self):
stream = QuicStream()
# nothing to send yet
self.assertTrue(stream.send_buffer_is_empty)
frame = stream.get_frame(8)
self.assertIsNone(frame)
# write EOF
stream.write(b"", end_stream=True)
self.assertFalse(stream.send_buffer_is_empty)
frame = stream.get_frame(8)
self.assertEqual(frame.data, b"")
self.assertTrue(frame.fin)
self.assertEqual(frame.offset, 0)
# nothing more to send
self.assertFalse(stream.send_buffer_is_empty) # FIXME?
frame = stream.get_frame(8)
self.assertIsNone(frame)
self.assertTrue(stream.send_buffer_is_empty)
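# End-to-end sketch of the QuicStream behaviour asserted above (reassembly of
# out-of-order receive data, chunked sending, retransmission of lost ranges);
# a usage illustration, not an excerpt from the library's documentation.
from aioquic.quic.packet import QuicStreamFrame
from aioquic.quic.packet_builder import QuicDeliveryState
from aioquic.quic.stream import QuicStream

# Receive side: data arriving beyond the current offset is buffered until the
# gap is filled, then delivered as a single StreamDataReceived event.
receiver = QuicStream(stream_id=0)
assert receiver.add_frame(QuicStreamFrame(offset=8, data=b"89012345")) is None
event = receiver.add_frame(QuicStreamFrame(offset=0, data=b"01234567"))
assert event.data == b"0123456789012345" and not event.end_stream

# Send side: written data goes out in get_frame()-sized chunks, and a chunk
# reported lost via on_data_delivery() is queued for retransmission.
sender = QuicStream()
sender.write(b"0123456789012345", end_stream=True)
assert sender.get_frame(8).data == b"01234567"      # fin=False
assert sender.get_frame(8).fin                      # carries the FIN
sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
sender.on_data_delivery(QuicDeliveryState.LOST, 8, 16)
assert sender.get_frame(8).data == b"89012345"      # lost chunk resent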

The diff for this file is not shown because of its large size.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/tls_certificate.bin vendored Normal file

Binary file not shown.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/tls_certificate_verify.bin vendored Normal file

Binary file not shown.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello.bin vendored Normal file

Binary file not shown.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_alpn.bin vendored Normal file

Binary file not shown.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_psk.bin vendored Normal file

Binary file not shown.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/tls_client_hello_with_sni.bin vendored Normal file

Binary file not shown.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions.bin vendored Normal file

Binary file not shown.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/tls_encrypted_extensions_with_alpn.bin vendored Normal file

Binary file not shown.

Binary file not shown.

Binary data
testing/web-platform/tests/tools/third_party/aioquic/tests/tls_finished.bin vendored Normal file

Binary file not shown.

Some files were not shown because too many files changed in this diff.