Bug 1695263
- Vendor in a copy of wptserve that's still Python 2 compatible, r=marionette-reviewers,whimboo
Upstream wptserve just switched to Python 3 only. That's fine for web-platform-tests, but it turns out that some marionette harness tests also use wptserve and are still on Python 2. Since fixing the marionette harness turns out to be non-trivial and this blocks other wpt work, this patch does the following:

* Temporarily vendors the last wptserve revision that works with Python 2 into testing/web-platform/mozilla/tests/tools/wptserve_py2.
* Configures the mach virtualenv to use that copy for Python 2 modules only.
* Configures the test packaging system to also put that copy in the common tests zip.

Requirements files are updated to use either the Python 2 version or the Python 3 version as required.

Differential Revision: https://phabricator.services.mozilla.com/D106764
@ -73,6 +73,7 @@ exclude =
@ -34,7 +34,8 @@ mozilla.pth:testing/web-platform/tests/tools/third_party/html5lib
@ -170,6 +170,12 @@ ARCHIVE_FILES = {
"pattern": "**",
"dest": "tps/tests",
"source": buildconfig.topsrcdir,
"base": "testing/web-platform/mozilla/tests/tools/wptserve_py2",
"pattern": "**",
"dest": "tools/wptserve_py2",
"source": buildconfig.topsrcdir,
"base": "testing/web-platform/tests/tools/wptserve",
@ -577,7 +583,6 @@ ARCHIVE_FILES = {
"base": "testing",
"pattern": "web-platform/tests/**",
"ignore": [
@ -1,6 +1,7 @@
-r mozbase_requirements.txt
../tools/wptserve_py2 ; python_version < '3'
../tools/wptserve ; python_version >= '3'
../tools/wpt_third_party/enum ; python_version < '3'
@ -2,7 +2,8 @@
-r mozbase_source_requirements.txt
../web-platform/tests/tools/wptserve ; python_version >= '3'
../web-platform/mozilla/tests/tools/wptserve_py2 ; python_version < '3'
../web-platform/tests/tools/third_party/enum ; python_version < '3'
@ -0,0 +1,11 @@
branch = True
parallel = True
omit =
wptserve =
@ -0,0 +1,40 @@
# C extensions
# Packages
# Installer logs
# Unit test / coverage reports
# Translations
# Mr Developer
@ -0,0 +1,11 @@
# The 3-Clause BSD License
Copyright 2019 web-platform-tests contributors
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
@ -0,0 +1 @@
include README.md
@ -0,0 +1,8 @@
Python 2 compatible version of wptserve.
This should be removed as soon as Python 2 is no longer required by
any code that uses wptserve (notably marionette-harness based tests).
When removing it, reverting the commit that added it should ensure that
we return to the previous configuration where wptserve is used from
the wpt import.
@ -0,0 +1,153 @@
testing the web platform. This means that extreme flexibility —
including the possibility of HTTP non-conformance — in the response is
.. toctree::
:maxdepth: 2
@ -0,0 +1,51 @@
wptserve has been designed with the specific goal of making a server
that is suitable for writing tests for the web platform. This means
that it cannot use common abstractions over HTTP such as WSGI, since
these assume that the goal is to generate a well-formed HTTP
response. Testcases, however, often require precise control of the
exact bytes sent over the wire and their timing. The full list of
design goals for the server is:
* Suitable to run on individual test machines and over the public internet.
* Support plain TCP and SSL servers.
* Serve static files with the minimum of configuration.
* Allow headers to be overwritten on a per-file and per-directory
  basis.
* Full customisation of headers sent (e.g. altering or omitting
"mandatory" headers).
* Simple per-client state.
* Complex logic in tests, up to precise control over the individual
bytes sent and the timing of sending them.
Request Handling
At the high level, the design of the server is based around similar
concepts to those found in common web frameworks like Django, Pyramid
or Flask. In particular the lifecycle of a typical request will be
familiar to users of these systems. Incoming requests are parsed and a
:doc:`Request <request>` object is constructed. This object is passed
to a :ref:`Router <router.Interface>` instance, which is
responsible for mapping the request method and path to a handler
function. This handler is passed two arguments: the request object and
a :doc:`Response <response>` object. In cases where only simple
responses are required, the handler function may fill in the
properties of the response object and the server will take care of
constructing the response. However, each Response also contains a
:ref:`ResponseWriter <response.Interface>` which can be
used to directly control the TCP socket.
By default there are several built-in handler functions that provide a
higher level API than direct manipulation of the Response
object. These are documented in :doc:`handlers`.
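
The lifecycle described above can be sketched in a few lines of plain Python. This is an illustration only: the `Request`, `Response`, `Router` and `handle` names below are hypothetical stand-ins written for this example, not wptserve's own classes, and the real router matches path expressions rather than exact paths.

```python
# Illustrative only: hypothetical stand-ins, not wptserve's classes.
class Request(object):
    def __init__(self, method, path):
        self.method = method
        self.path = path


class Response(object):
    def __init__(self):
        self.status = 200
        self.headers = []
        self.content = ""


class Router(object):
    def __init__(self):
        self.routes = {}

    def register(self, method, path, handler):
        # Exact-match lookup keeps the sketch short.
        self.routes[(method, path)] = handler

    def get_handler(self, request):
        return self.routes.get((request.method, request.path))


def handle(router, method, path):
    """Build the request, route it, call the handler, return the response."""
    request = Request(method, path)
    response = Response()
    handler = router.get_handler(request)
    if handler is None:
        response.status = 404
        return response
    # The handler receives both objects and fills in the response.
    handler(request, response)
    return response


def hello(request, response):
    response.headers.append(("Content-Type", "text/plain"))
    response.content = "hello from " + request.path


router = Router()
router.register("GET", "/hello", hello)
result = handle(router, "GET", "/hello")
missing = handle(router, "GET", "/nope")
```

Here `result.content` is `"hello from /hello"` and `missing.status` is `404`; a handler that needs byte-level control would instead write to the socket itself, which this sketch does not model.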
@ -0,0 +1,8 @@
:mod:`Interface <wptserve.pipes>`
.. automodule:: wptserve.pipes
@ -0,0 +1,10 @@
Request object.
:mod:`Interface <wptserve.request>`
.. automodule:: wptserve.request
@ -0,0 +1,41 @@
Response object. This object is used to control the response that will
be sent to the HTTP client. A handler function will take the response
object and fill in various parts of the response. For example, a plain
text response with the body 'Some example content' could be produced as::

    def handler(request, response):
        response.headers.set("Content-Type", "text/plain")
        response.content = "Some example content"

The response object also gives access to a ResponseWriter, which
allows direct access to the response socket. For example, one could
write a similar response but with more explicit control as follows::

    import time

    def handler(request, response):
        response.add_required_headers = False # Don't implicitly add HTTP headers
        response.writer.write_status(200)
        response.writer.write_header("Content-Type", "text/plain")
        response.writer.write_header("Content-Length", len("Some example content"))
        response.writer.end_headers()
        response.writer.write("Some ")
        time.sleep(1)  # pause between the two writes
        response.writer.write("example content")

Note that when writing the response directly like this it is always
necessary to either set the Content-Length header or set
`response.close_connection = True`. Without one of these, the client
will not be able to determine where the response body ends and will
continue to load indefinitely.
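
To see why the client needs one of those two signals, here is a minimal sketch that frames a response by hand and then tries to find the end of the body the way a simple client might. The `frame_response` and `read_body` helpers are hypothetical and exist only for this illustration; they do not use wptserve.

```python
def frame_response(body, content_length=None):
    """Build raw HTTP/1.1 response bytes, optionally with Content-Length."""
    lines = [b"HTTP/1.1 200 OK", b"Content-Type: text/plain"]
    if content_length is not None:
        lines.append(b"Content-Length: " + str(content_length).encode("ascii"))
    return b"\r\n".join(lines) + b"\r\n\r\n" + body


def read_body(raw):
    """Return the body if its end can be determined, else None.

    In the None case a real client would keep waiting on the socket
    until the server closes the connection.
    """
    head, _, rest = raw.partition(b"\r\n\r\n")
    for line in head.split(b"\r\n")[1:]:
        name, _, value = line.partition(b":")
        if name.strip().lower() == b"content-length":
            return rest[:int(value.strip().decode("ascii"))]
    return None


body = b"Some example content"
with_length = read_body(frame_response(body, len(body)))
without_length = read_body(frame_response(body))
```

With the header present, `with_length` is the full body; without it, `without_length` is `None`, which stands in for a client that cannot tell where the body ends.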
.. _response.Interface:
:mod:`Interface <wptserve.response>`
.. automodule:: wptserve.response
@ -0,0 +1,78 @@
The router is used to match incoming requests to request handler
functions. Typically users don't interact with the router directly,
but instead send a list of routes to register when starting the
server. However it is also possible to add routes after starting the
server by calling the `register` method on the server's `router`
property.
Routes are represented by a three item tuple::
(methods, path_match, handler)
`methods` is either a string or a list of strings indicating the HTTP
methods to match. In cases where all methods should match there is a
special sentinel value `any_method` provided as a property of the
`router` module that can be used.
`path_match` is an expression that will be evaluated against the
request path to decide if the handler should match. These expressions
follow a custom syntax intended to make matching URLs straightforward
and, in particular, to be easier to use than raw regexp for URL
matching. There are three possible components of a match expression:
* Literals. These match any character. The special characters \*, \{
and \} must be escaped by prefixing them with a \\.
* Match groups. These match any character other than / and save the
result as a named group. They are delimited by curly braces; for
example, `{abc}` would create a match group with the name `abc`.
* Stars. These are denoted with a `*` and match any character
including /. There can be at most one star
per pattern and it must follow any match groups.
Path expressions always match the entire request path and a leading /
in the expression is implied even if it is not explicitly
provided. This means that `/foo` and `foo` are equivalent.
For example, the following pattern matches all requests for resources with the
extension `.py`::
The following expression matches anything directly under `/resources`
with a `.html` extension, and places the "filename" in the `name`
group::
The groups, including anything that matches a `*`, are available in the
request object through the `route_match` property. This is a
dictionary mapping the group names, and any match for `*`, to the
matching part of the route. For example, given a route::
and the request path `/api/test/html/test.html`, `route_match` would
be::
{"sub_api": "html", "*": "html/test.html"}
`handler` is a function taking a request and a response object that is
responsible for constructing the response to the HTTP request. See
:doc:`handlers` for more details on handler functions.
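
As a rough illustration of the `path_match` rules above, the following sketch translates a match expression into a regular expression. It is an assumption-laden toy, not wptserve's implementation: `compile_route` and `match_route` are hypothetical names, the patterns in the usage lines are chosen for this example, and the check that a star must follow any match groups is omitted.

```python
import re

# An escaped character, a {name} match group, or a star.
_TOKEN = re.compile(r"\\(.)|\{(\w+)\}|(\*)")


def compile_route(expr):
    """Translate a match expression into a compiled regular expression."""
    if not expr.startswith("/"):
        expr = "/" + expr  # a leading / is implied
    parts = []
    pos = 0
    for m in _TOKEN.finditer(expr):
        parts.append(re.escape(expr[pos:m.start()]))  # literal text
        escaped, group, star = m.groups()
        if escaped is not None:
            parts.append(re.escape(escaped))
        elif group is not None:
            parts.append("(?P<%s>[^/]+)" % group)  # anything except /
        else:
            # A second star would reuse this group name and fail to
            # compile, which loosely mirrors "at most one star".
            parts.append("(?P<_star>.*)")
        pos = m.end()
    parts.append(re.escape(expr[pos:]))
    # Anchored at both ends: expressions match the entire request path.
    return re.compile("^" + "".join(parts) + "$")


def match_route(expr, path):
    """Return a dict of matched groups, or None if the path does not match."""
    m = compile_route(expr).match(path)
    if m is None:
        return None
    groups = m.groupdict()
    if "_star" in groups:
        groups["*"] = groups.pop("_star")
    return groups


py_match = match_route("*.py", "/a/b/test.py")
name_match = match_route("/resources/{name}.html", "/resources/index.html")
nested = match_route("/resources/{name}.html", "/resources/sub/index.html")
```

In this sketch `py_match` is `{"*": "a/b/test"}`, `name_match` is `{"name": "index"}`, and `nested` is `None` because a match group never spans a `/`.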
.. _router.Interface:
:mod:`Interface <wptserve.router>`
.. automodule:: wptserve.router
@ -0,0 +1,20 @@
Basic server classes and router.
The following example creates a server that serves static files from
the `files` subdirectory of the current directory and causes it to
run on port 8080 until it is killed::

    from wptserve import server, handlers

    httpd = server.WebTestHttpd(port=8080, doc_root="./files/",
                                routes=[("GET", "*", handlers.file_handler)])
    httpd.start(block=True)

:mod:`Interface <wptserve.server>`
.. automodule:: wptserve.server
@ -0,0 +1,31 @@
Object for storing cross-request state. This is unusual in that keys
must be UUIDs, in order to prevent different clients setting the same
key, and values are write-once, read-once to minimize the chances of
state persisting indefinitely. The stash defines two operations:
`put`, to add state, and `take`, to remove state. Furthermore, the view
of the stash is path-specific; by default a request will only see the
part of the stash corresponding to its own path.
A typical example of using a stash to store state might be::

    def handler(request, response):
        # We assume this is a string representing a UUID
        key = request.GET.first("id")

        if request.method == "POST":
            request.server.stash.put(key, "Some sample value")
            return "Added value to stash"
        else:
            value = request.server.stash.take(key)
            # Values are read-once, so a second take() finds nothing
            assert request.server.stash.take(key) is None
            return value

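
The write-once, read-once and path-specific behaviour can be modelled with a small in-memory class. This is a sketch under stated assumptions, not wptserve's stash: `MiniStash` is a hypothetical name, the path is passed explicitly rather than taken from the request, and nothing here is shared between processes.

```python
import uuid


class MiniStash(object):
    """Hypothetical write-once, read-once store keyed by (path, UUID)."""

    def __init__(self):
        self._data = {}

    def _key(self, key, path):
        # uuid.UUID raises ValueError for anything that is not a UUID string.
        return (path, uuid.UUID(key).hex)

    def put(self, key, value, path="/"):
        internal = self._key(key, path)
        if internal in self._data:
            raise KeyError("key already set: %s" % key)
        self._data[internal] = value

    def take(self, key, path="/"):
        # pop() removes the value, so a second take() returns None.
        return self._data.pop(self._key(key, path), None)


stash = MiniStash()
key = str(uuid.uuid4())
stash.put(key, "Some sample value", path="/a")
other_path = stash.take(key, path="/b")  # a different path sees nothing
first = stash.take(key, path="/a")
second = stash.take(key, path="/a")
```

Here `other_path` and `second` are both `None`, while `first` is the stored value. Keying on the path as well as the UUID is one simple way to give each path its own view of the stash.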
:mod:`Interface <wptserve.stash>`
.. automodule:: wptserve.stash
md5: {{file_hash(md5, sub_file_hash_subject.txt)}}
sha1: {{file_hash(sha1, sub_file_hash_subject.txt)}}
sha224: {{file_hash(sha224, sub_file_hash_subject.txt)}}
sha256: {{file_hash(sha256, sub_file_hash_subject.txt)}}
sha384: {{file_hash(sha384, sub_file_hash_subject.txt)}}
sha512: {{file_hash(sha512, sub_file_hash_subject.txt)}}
@ -0,0 +1,2 @@
This file is used to verify expected behavior of the `file_hash` "sub"
function.
@ -0,0 +1 @@
{{file_hash(sha007, sub_file_hash_subject.txt)}}
@ -0,0 +1,2 @@
{{header_or_default(X-Present, present-default)}}
{{header_or_default(X-Absent, absent-default)}}
@ -0,0 +1 @@
@ -0,0 +1 @@
@ -0,0 +1,8 @@
host: {{location[host]}}
hostname: {{location[hostname]}}
path: {{location[path]}}
pathname: {{location[pathname]}}
port: {{location[port]}}
query: {{location[query]}}
scheme: {{location[scheme]}}
server: {{location[server]}}
@ -0,0 +1 @@
{{GET[plus pct-20 pct-3D=]}}
@ -0,0 +1 @@
{{GET[plus pct-20 pct-3D=]}}
@ -0,0 +1 @@
Before {{url_base}} After
@ -0,0 +1 @@
Before {{uuid()}} After
@ -0,0 +1 @@
{{$first:host}} {{$second:ports[http][0]}} A {{$second}} B {{$first}} C
@ -0,0 +1,2 @@
def module_function():
return [("Content-Type", "text/plain")], "PASS"
@ -0,0 +1 @@
I am here to ensure that my containing directory exists.
@ -0,0 +1,5 @@
from subdir import example_module
def main(request, response):
return example_module.module_function()
@ -0,0 +1,3 @@
@ -0,0 +1,5 @@
HTTP/1.1 202 Giraffe
Content-Length: 7
@ -0,0 +1,2 @@
def handle_data(frame, request, response):
response.content = frame.data[::-1]
@ -0,0 +1,3 @@
def handle_headers(frame, request, response):
response.status = 203
response.headers.update([('test', 'passed')])
@ -0,0 +1,6 @@
def handle_headers(frame, request, response):
response.status = 203
response.headers.update([('test', 'passed')])
def handle_data(frame, request, response):
response.content = frame.data[::-1]
@ -0,0 +1,3 @@
def main(request, response):
response.headers.set("Content-Type", "text/plain")
return "PASS"
@ -0,0 +1,2 @@
def main(request, response):
return [("Content-Type", "text/html"), ("X-Test", "PASS")], "PASS"
@ -0,0 +1,2 @@
def main(request, response):
return (202, "Giraffe"), [("Content-Type", "text/html"), ("X-Test", "PASS")], "PASS"
@ -0,0 +1 @@
Test document with custom headers
@ -0,0 +1,6 @@
Custom-Header: PASS
Another-Header: {{$id:uuid()}}
Same-Value-Header: {{$id}}
Double-Header: PA
Double-Header: SS
Content-Type: text/html
import json
import os
import sys
import unittest
import uuid
import pytest
from six.moves.urllib.error import HTTPError
wptserve = pytest.importorskip("wptserve")
from .base import TestUsingServer, TestUsingH2Server, doc_root
from .base import TestWrapperHandlerUsingServer
from serve import serve
class TestFileHandler(TestUsingServer):
def test_GET(self):
resp = self.request("/document.txt")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/plain", resp.info()["Content-Type"])
self.assertEqual(open(os.path.join(doc_root, "document.txt"), 'rb').read(), resp.read())
def test_headers(self):
resp = self.request("/with_headers.txt")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/html", resp.info()["Content-Type"])
self.assertEqual("PASS", resp.info()["Custom-Header"])
# This will fail if it isn't a valid uuid
self.assertEqual(resp.info()["Same-Value-Header"], resp.info()["Another-Header"])
self.assert_multiple_headers(resp, "Double-Header", ["PA", "SS"])
def test_range(self):
resp = self.request("/document.txt", headers={"Range":"bytes=10-19"})
self.assertEqual(206, resp.getcode())
data = resp.read()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(10, len(data))
self.assertEqual("bytes 10-19/%i" % len(expected), resp.info()['Content-Range'])
self.assertEqual("10", resp.info()['Content-Length'])
self.assertEqual(expected[10:20], data)
def test_range_no_end(self):
resp = self.request("/document.txt", headers={"Range":"bytes=10-"})
self.assertEqual(206, resp.getcode())
data = resp.read()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(len(expected) - 10, len(data))
self.assertEqual("bytes 10-%i/%i" % (len(expected) - 1, len(expected)), resp.info()['Content-Range'])
self.assertEqual(expected[10:], data)
def test_range_no_start(self):
resp = self.request("/document.txt", headers={"Range":"bytes=-10"})
self.assertEqual(206, resp.getcode())
data = resp.read()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(10, len(data))
self.assertEqual("bytes %i-%i/%i" % (len(expected) - 10, len(expected) - 1, len(expected)),
                 resp.info()['Content-Range'])
self.assertEqual(expected[-10:], data)
def test_multiple_ranges(self):
resp = self.request("/document.txt", headers={"Range":"bytes=1-2,5-7,6-10"})
self.assertEqual(206, resp.getcode())
data = resp.read()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertTrue(resp.info()["Content-Type"].startswith("multipart/byteranges; boundary="))
boundary = resp.info()["Content-Type"].split("boundary=")[1]
parts = data.split(b"--" + boundary.encode("ascii"))
self.assertEqual(b"\r\n", parts[0])
self.assertEqual(b"--", parts[-1])
expected_parts = [(b"1-2", expected[1:3]), (b"5-10", expected[5:11])]
for expected_part, part in zip(expected_parts, parts[1:-1]):
header_string, body = part.split(b"\r\n\r\n")
headers = dict(item.split(b": ", 1) for item in header_string.split(b"\r\n") if item.strip())
self.assertEqual(headers[b"Content-Type"], b"text/plain")
self.assertEqual(headers[b"Content-Range"], b"bytes %s/%i" % (expected_part[0], len(expected)))
self.assertEqual(expected_part[1] + b"\r\n", body)
def test_range_invalid(self):
with self.assertRaises(HTTPError) as cm:
self.request("/document.txt", headers={"Range":"bytes=11-10"})
self.assertEqual(cm.exception.code, 416)
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
with self.assertRaises(HTTPError) as cm:
self.request("/document.txt", headers={"Range":"bytes=%i-%i" % (len(expected), len(expected) + 10)})
self.assertEqual(cm.exception.code, 416)
def test_sub_config(self):
resp = self.request("/sub.sub.txt")
expected = b"localhost localhost %i" % self.server.port
assert resp.read().rstrip() == expected
def test_sub_headers(self):
resp = self.request("/sub_headers.sub.txt", headers={"X-Test": "PASS"})
expected = b"PASS"
assert resp.read().rstrip() == expected
def test_sub_params(self):
resp = self.request("/sub_params.txt", query="plus+pct-20%20pct-3D%3D=PLUS+PCT-20%20PCT-3D%3D&pipe=sub")
expected = b"PLUS PCT-20 PCT-3D="
assert resp.read().rstrip() == expected
class TestFunctionHandler(TestUsingServer):
def test_string_rv(self):
def handler(request, response):
return "test data"
route = ("GET", "/test/test_string_rv", handler)
resp = self.request(route[1])
self.assertEqual(200, resp.getcode())
self.assertEqual("9", resp.info()["Content-Length"])
self.assertEqual(b"test data", resp.read())
def test_tuple_1_rv(self):
def handler(request, response):
return ()
route = ("GET", "/test/test_tuple_1_rv", handler)
with pytest.raises(HTTPError) as cm:
    self.request(route[1])
assert cm.value.code == 500
def test_tuple_2_rv(self):
def handler(request, response):
return [("Content-Length", 4), ("test-header", "test-value")], "test data"
route = ("GET", "/test/test_tuple_2_rv", handler)
resp = self.request(route[1])
self.assertEqual(200, resp.getcode())
self.assertEqual("4", resp.info()["Content-Length"])
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual(b"test", resp.read())
def test_tuple_3_rv(self):
def handler(request, response):
return 202, [("test-header", "test-value")], "test data"
route = ("GET", "/test/test_tuple_3_rv", handler)
resp = self.request(route[1])
self.assertEqual(202, resp.getcode())
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual(b"test data", resp.read())
def test_tuple_3_rv_1(self):
def handler(request, response):
return (202, "Some Status"), [("test-header", "test-value")], "test data"
route = ("GET", "/test/test_tuple_3_rv_1", handler)
resp = self.request(route[1])
self.assertEqual(202, resp.getcode())
self.assertEqual("Some Status", resp.msg)
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual(b"test data", resp.read())
def test_tuple_4_rv(self):
def handler(request, response):
return 202, [("test-header", "test-value")], "test data", "garbage"
route = ("GET", "/test/test_tuple_1_rv", handler)
with pytest.raises(HTTPError) as cm:
    self.request(route[1])
assert cm.value.code == 500
def test_none_rv(self):
def handler(request, response):
return None
route = ("GET", "/test/test_none_rv", handler)
resp = self.request(route[1])
assert resp.getcode() == 200
assert "Content-Length" not in resp.info()
assert resp.read() == b""
class TestJSONHandler(TestUsingServer):
def test_json_0(self):
def handler(request, response):
return {"data": "test data"}
route = ("GET", "/test/test_json_0", handler)
resp = self.request(route[1])
self.assertEqual(200, resp.getcode())
self.assertEqual({"data": "test data"}, json.load(resp))
def test_json_tuple_2(self):
def handler(request, response):
return [("Test-Header", "test-value")], {"data": "test data"}
route = ("GET", "/test/test_json_tuple_2", handler)
resp = self.request(route[1])
self.assertEqual(200, resp.getcode())
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual({"data": "test data"}, json.load(resp))
def test_json_tuple_3(self):
def handler(request, response):
return (202, "Giraffe"), [("Test-Header", "test-value")], {"data": "test data"}
route = ("GET", "/test/test_json_tuple_2", handler)
resp = self.request(route[1])
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual({"data": "test data"}, json.load(resp))
class TestPythonHandler(TestUsingServer):
def test_string(self):
resp = self.request("/test_string.py")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/plain", resp.info()["Content-Type"])
self.assertEqual(b"PASS", resp.read())
def test_tuple_2(self):
resp = self.request("/test_tuple_2.py")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/html", resp.info()["Content-Type"])
self.assertEqual("PASS", resp.info()["X-Test"])
self.assertEqual(b"PASS", resp.read())
def test_tuple_3(self):
resp = self.request("/test_tuple_3.py")
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
self.assertEqual("text/html", resp.info()["Content-Type"])
self.assertEqual("PASS", resp.info()["X-Test"])
self.assertEqual(b"PASS", resp.read())
def test_import(self):
dir_name = os.path.join(doc_root, "subdir")
assert dir_name not in sys.path
assert "test_module" not in sys.modules
resp = self.request("/subdir/import_handler.py")
assert dir_name not in sys.path
assert "test_module" not in sys.modules
self.assertEqual(200, resp.getcode())
self.assertEqual("text/plain", resp.info()["Content-Type"])
self.assertEqual(b"PASS", resp.read())
def test_no_main(self):
with pytest.raises(HTTPError) as cm:
assert cm.value.code == 500
def test_invalid(self):
with pytest.raises(HTTPError) as cm:
assert cm.value.code == 500
def test_missing(self):
with pytest.raises(HTTPError) as cm:
assert cm.value.code == 404
class TestDirectoryHandler(TestUsingServer):
def test_directory(self):
resp = self.request("/")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/html", resp.info()["Content-Type"])
#Add a check that the response is actually sane
def test_subdirectory_trailing_slash(self):
resp = self.request("/subdir/")
assert resp.getcode() == 200
assert resp.info()["Content-Type"] == "text/html"
def test_subdirectory_no_trailing_slash(self):
# This seems to resolve the 301 transparently, so test for 200
resp = self.request("/subdir")
assert resp.getcode() == 200
assert resp.info()["Content-Type"] == "text/html"
class TestAsIsHandler(TestUsingServer):
def test_as_is(self):
resp = self.request("/test.asis")
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
self.assertEqual("PASS", resp.info()["X-Test"])
self.assertEqual(b"Content", resp.read())
#Add a check that the response is actually sane
class TestH2Handler(TestUsingH2Server):
def test_handle_headers(self):
self.conn.request("GET", '/test_h2_headers.py')
resp = self.conn.get_response()
assert resp.status == 203
assert resp.headers['test'][0] == b'passed'
assert resp.read() == b''
def test_only_main(self):
self.conn.request("GET", '/test_tuple_3.py')
resp = self.conn.get_response()
assert resp.status == 202
assert resp.headers['Content-Type'][0] == b'text/html'
assert resp.headers['X-Test'][0] == b'PASS'
assert resp.read() == b'PASS'
def test_handle_data(self):
self.conn.request("POST", '/test_h2_data.py', body="hello world!")
resp = self.conn.get_response()
assert resp.status == 200
assert resp.read() == b'!dlrow olleh'
def test_handle_headers_data(self):
self.conn.request("POST", '/test_h2_headers_data.py', body="hello world!")
resp = self.conn.get_response()
assert resp.status == 203
assert resp.headers['test'][0] == b'passed'
assert resp.read() == b'!dlrow olleh'
def test_no_main_or_handlers(self):
self.conn.request("GET", '/no_main.py')
resp = self.conn.get_response()
assert resp.status == 500
assert "No main function or handlers in script " in json.loads(resp.read())["error"]["message"]
def test_not_found(self):
self.conn.request("GET", '/no_exist.py')
resp = self.conn.get_response()
assert resp.status == 404
def test_requesting_multiple_resources(self):
# 1st .py resource
self.conn.request("GET", '/test_h2_headers.py')
resp = self.conn.get_response()
assert resp.status == 203
assert resp.headers['test'][0] == b'passed'
assert resp.read() == b''
# 2nd .py resource
self.conn.request("GET", '/test_tuple_3.py')
resp = self.conn.get_response()
assert resp.status == 202
assert resp.headers['Content-Type'][0] == b'text/html'
assert resp.headers['X-Test'][0] == b'PASS'
assert resp.read() == b'PASS'
# 3rd .py resource
self.conn.request("GET", '/test_h2_headers.py')
resp = self.conn.get_response()
assert resp.status == 203
assert resp.headers['test'][0] == b'passed'
assert resp.read() == b''
class TestWorkersHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.worker.js': b'',
'foo.any.js': b''}
def test_any_worker_html(self):
'text/html', serve.WorkersHandler)
def test_worker_html(self):
'text/html', serve.WorkersHandler)
class TestWindowHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.window.js': b''}
def test_window_html(self):
'text/html', serve.WindowHandler)
class TestAnyHtmlHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.any.js': b'',
'foo.any.js.headers': b'X-Foo: 1',
'__dir__.headers': b'X-Bar: 2'}
def test_any_html(self):
headers=[('X-Foo', '1'), ('X-Bar', '2')])
class TestSharedWorkersHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.any.js': b'// META: global=sharedworker\n'}
def test_any_sharedworkers_html(self):
'text/html', serve.SharedWorkersHandler)
class TestServiceWorkersHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.any.js': b'// META: global=serviceworker\n'}
def test_serviceworker_html(self):
'text/html', serve.ServiceWorkersHandler)
class TestAnyWorkerHandler(TestWrapperHandlerUsingServer):
dummy_files = {'bar.any.js': b''}
def test_any_work_js(self):
self.run_wrapper_test('bar.any.worker.js', 'text/javascript',
if __name__ == '__main__':
    unittest.main()
import os
import unittest
import time
import json
from six import assertRegex
from six.moves import urllib
import pytest
wptserve = pytest.importorskip("wptserve")
from .base import TestUsingServer, doc_root
class TestStatus(TestUsingServer):
def test_status(self):
resp = self.request("/document.txt", query="pipe=status(202)")
self.assertEqual(resp.getcode(), 202)
class TestHeader(TestUsingServer):
def test_not_set(self):
resp = self.request("/document.txt", query="pipe=header(X-TEST,PASS)")
self.assertEqual(resp.info()["X-TEST"], "PASS")
def test_set(self):
resp = self.request("/document.txt", query="pipe=header(Content-Type,text/html)")
self.assertEqual(resp.info()["Content-Type"], "text/html")
def test_multiple(self):
resp = self.request("/document.txt", query="pipe=header(X-Test,PASS)|header(Content-Type,text/html)")
self.assertEqual(resp.info()["X-TEST"], "PASS")
self.assertEqual(resp.info()["Content-Type"], "text/html")
def test_multiple_same(self):
resp = self.request("/document.txt", query="pipe=header(Content-Type,FAIL)|header(Content-Type,text/html)")
self.assertEqual(resp.info()["Content-Type"], "text/html")
def test_multiple_append(self):
resp = self.request("/document.txt", query="pipe=header(X-Test,1)|header(X-Test,2,True)")
self.assert_multiple_headers(resp, "X-Test", ["1", "2"])
def test_semicolon(self):
resp = self.request("/document.txt", query="pipe=header(Refresh,3;url=http://example.com)")
self.assertEqual(resp.info()["Refresh"], "3;url=http://example.com")
def test_escape_comma(self):
resp = self.request("/document.txt", query=r"pipe=header(Expires,Thu\,%2014%20Aug%201986%2018:00:00%20GMT)")
self.assertEqual(resp.info()["Expires"], "Thu, 14 Aug 1986 18:00:00 GMT")
def test_escape_parenthesis(self):
resp = self.request("/document.txt", query=r"pipe=header(User-Agent,Mozilla/5.0%20(X11;%20Linux%20x86_64;%20rv:12.0\)")
self.assertEqual(resp.info()["User-Agent"], "Mozilla/5.0 (X11; Linux x86_64; rv:12.0)")
class TestSlice(TestUsingServer):
def test_both_bounds(self):
resp = self.request("/document.txt", query="pipe=slice(1,10)")
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(resp.read(), expected[1:10])
def test_no_upper(self):
resp = self.request("/document.txt", query="pipe=slice(1)")
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(resp.read(), expected[1:])
def test_no_lower(self):
resp = self.request("/document.txt", query="pipe=slice(null,10)")
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(resp.read(), expected[:10])
class TestSub(TestUsingServer):
def test_sub_config(self):
resp = self.request("/sub.txt", query="pipe=sub")
expected = b"localhost localhost %i" % self.server.port
self.assertEqual(resp.read().rstrip(), expected)
def test_sub_file_hash(self):
resp = self.request("/sub_file_hash.sub.txt")
expected = b"""
md5: JmI1W8fMHfSfCarYOSxJcw==
sha1: nqpWqEw4IW8NjD6R375gtrQvtTo=
sha224: RqQ6fMmta6n9TuA/vgTZK2EqmidqnrwBAmQLRQ==
sha256: G6Ljg1uPejQxqFmvFOcV/loqnjPTW5GSOePOfM/u0jw=
sha384: lkXHChh1BXHN5nT5BYhi1x67E1CyYbPKRKoF2LTm5GivuEFpVVYtvEBHtPr74N9E
sha512: r8eLGRTc7ZznZkFjeVLyo6/FyQdra9qmlYCwKKxm3kfQAswRS9+3HsYk3thLUhcFmmWhK4dXaICzJwGFonfXwg=="""
self.assertEqual(resp.read().rstrip(), expected.strip())
def test_sub_file_hash_unrecognized(self):
with self.assertRaises(urllib.error.HTTPError):
def test_sub_headers(self):
resp = self.request("/sub_headers.txt", query="pipe=sub", headers={"X-Test": "PASS"})
expected = b"PASS"
self.assertEqual(resp.read().rstrip(), expected)
def test_sub_location(self):
resp = self.request("/sub_location.sub.txt?query_string")
expected = """
host: localhost:{0}
hostname: localhost
path: /sub_location.sub.txt
pathname: /sub_location.sub.txt
port: {0}
query: ?query_string
scheme: http
server: http://localhost:{0}""".format(self.server.port).encode("ascii")
self.assertEqual(resp.read().rstrip(), expected.strip())
def test_sub_params(self):
resp = self.request("/sub_params.txt", query="plus+pct-20%20pct-3D%3D=PLUS+PCT-20%20PCT-3D%3D&pipe=sub")
expected = b"PLUS PCT-20 PCT-3D="
self.assertEqual(resp.read().rstrip(), expected)
def test_sub_url_base(self):
resp = self.request("/sub_url_base.sub.txt")
self.assertEqual(resp.read().rstrip(), b"Before / After")
def test_sub_url_base_via_filename_with_query(self):
resp = self.request("/sub_url_base.sub.txt?pipe=slice(5,10)")
self.assertEqual(resp.read().rstrip(), b"e / A")
def test_sub_uuid(self):
resp = self.request("/sub_uuid.sub.txt")
assertRegex(self, resp.read().rstrip(), b"Before [a-f0-9-]+ After")
def test_sub_var(self):
resp = self.request("/sub_var.sub.txt")
port = self.server.port
expected = b"localhost %d A %d B localhost C" % (port, port)
self.assertEqual(resp.read().rstrip(), expected)
def test_sub_fs_path(self):
resp = self.request("/subdir/sub_path.sub.txt")
root = os.path.abspath(doc_root)
expected = """%(root)s%(sep)ssubdir%(sep)ssub_path.sub.txt
""" % {"root": root, "sep": os.path.sep}
self.assertEqual(resp.read(), expected.encode("utf8"))
def test_sub_header_or_default(self):
resp = self.request("/sub_header_or_default.sub.txt", headers={"X-Present": "OK"})
expected = b"OK\nabsent-default"
self.assertEqual(resp.read().rstrip(), expected)
class TestTrickle(TestUsingServer):
def test_trickle(self):
#Actually testing that the response trickles in is not that easy
t0 = time.time()
resp = self.request("/document.txt", query="pipe=trickle(1:d2:5:d1:r2)")
t1 = time.time()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(resp.read(), expected)
self.assertGreater(6, t1-t0)
def test_headers(self):
resp = self.request("/document.txt", query="pipe=trickle(d0.01)")
self.assertEqual(resp.info()["Cache-Control"], "no-cache, no-store, must-revalidate")
self.assertEqual(resp.info()["Pragma"], "no-cache")
self.assertEqual(resp.info()["Expires"], "0")
class TestPipesWithVariousHandlers(TestUsingServer):
def test_with_python_file_handler(self):
resp = self.request("/test_string.py", query="pipe=slice(null,2)")
self.assertEqual(resp.read(), b"PA")
def test_with_python_func_handler(self):
def handler(request, response):
return "PASS"
route = ("GET", "/test/test_pipes_1/", handler)
resp = self.request(route[1], query="pipe=slice(null,2)")
self.assertEqual(resp.read(), b"PA")
def test_with_python_func_handler_using_response_writer(self):
def handler(request, response):
route = ("GET", "/test/test_pipes_1/", handler)
resp = self.request(route[1], query="pipe=slice(null,2)")
# slice has not been applied to the response, because response.writer was used.
self.assertEqual(resp.read(), b"PASS")
def test_header_pipe_with_python_func_using_response_writer(self):
def handler(request, response):
route = ("GET", "/test/test_pipes_1/", handler)
resp = self.request(route[1], query="pipe=header(X-TEST,FAIL)")
# header pipe was ignored, because response.writer was used.
self.assertEqual(resp.read(), b"CONTENT")
def test_with_json_handler(self):
def handler(request, response):
return json.dumps({'data': 'PASS'})
route = ("GET", "/test/test_pipes_2/", handler)
resp = self.request(route[1], query="pipe=slice(null,2)")
self.assertEqual(resp.read(), b'"{')
def test_slice_with_as_is_handler(self):
resp = self.request("/test.asis", query="pipe=slice(null,2)")
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
self.assertEqual("PASS", resp.info()["X-Test"])
# slice has not been applied to the response, because response.writer was used.
self.assertEqual(b"Content", resp.read())
def test_headers_with_as_is_handler(self):
resp = self.request("/test.asis", query="pipe=header(X-TEST,FAIL)")
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
# header pipe was ignored.
self.assertEqual("PASS", resp.info()["X-TEST"])
self.assertEqual(b"Content", resp.read())
def test_trickle_with_as_is_handler(self):
t0 = time.time()
resp = self.request("/test.asis", query="pipe=trickle(1:d2:5:d1:r2)")
t1 = time.time()
self.assertTrue(b'Content' in resp.read())
self.assertGreater(6, t1-t0)
if __name__ == '__main__':
import os
import unittest
import json
from io import BytesIO
import pytest
from six import create_bound_method, PY3
from six.moves.http_client import BadStatusLine
wptserve = pytest.importorskip("wptserve")
from .base import TestUsingServer, TestUsingH2Server, doc_root
from h2.exceptions import ProtocolError
def send_body_as_header(self):
if self._response.add_required_headers:
self.write("X-Body: ")
self._headers_complete = True
class TestResponse(TestUsingServer):
def test_head_without_body(self):
def handler(request, response):
response.writer.end_headers = create_bound_method(send_body_as_header,
return [("X-Test", "TEST")], "body\r\n"
route = ("GET", "/test/test_head_without_body", handler)
resp = self.request(route[1], method="HEAD")
self.assertEqual("6", resp.info()['Content-Length'])
self.assertEqual("TEST", resp.info()['x-Test'])
self.assertEqual("", resp.info()['x-body'])
def test_head_with_body(self):
def handler(request, response):
response.send_body_for_head_request = True
response.writer.end_headers = create_bound_method(send_body_as_header,
return [("X-Test", "TEST")], "body\r\n"
route = ("GET", "/test/test_head_with_body", handler)
resp = self.request(route[1], method="HEAD")
self.assertEqual("6", resp.info()['Content-Length'])
self.assertEqual("TEST", resp.info()['x-Test'])
self.assertEqual("body", resp.info()['X-Body'])
def test_write_content_no_status_no_header(self):
resp_content = b"TEST"
def handler(request, response):
route = ("GET", "/test/test_write_content_no_status_no_header", handler)
resp = self.request(route[1])
assert resp.getcode() == 200
assert resp.read() == resp_content
assert resp.info()["Content-Length"] == str(len(resp_content))
assert "Date" in resp.info()
assert "Server" in resp.info()
def test_write_content_no_headers(self):
resp_content = b"TEST"
def handler(request, response):
route = ("GET", "/test/test_write_content_no_headers", handler)
resp = self.request(route[1])
assert resp.getcode() == 201
assert resp.read() == resp_content
assert resp.info()["Content-Length"] == str(len(resp_content))
assert "Date" in resp.info()
assert "Server" in resp.info()
def test_write_content_no_status(self):
resp_content = b"TEST"
def handler(request, response):
response.writer.write_header("test-header", "test-value")
route = ("GET", "/test/test_write_content_no_status", handler)
resp = self.request(route[1])
assert resp.getcode() == 200
assert resp.read() == resp_content
assert sorted([x.lower() for x in resp.info().keys()]) == sorted(['test-header', 'date', 'server', 'content-length'])
def test_write_content_no_status_no_required_headers(self):
resp_content = b"TEST"
def handler(request, response):
response.add_required_headers = False
response.writer.write_header("test-header", "test-value")
route = ("GET", "/test/test_write_content_no_status_no_required_headers", handler)
resp = self.request(route[1])
assert resp.getcode() == 200
assert resp.read() == resp_content
assert resp.info().items() == [('test-header', 'test-value')]
def test_write_content_no_status_no_headers_no_required_headers(self):
resp_content = b"TEST"
def handler(request, response):
response.add_required_headers = False
route = ("GET", "/test/test_write_content_no_status_no_headers_no_required_headers", handler)
resp = self.request(route[1])
assert resp.getcode() == 200
assert resp.read() == resp_content
assert resp.info().items() == []
def test_write_raw_content(self):
resp_content = b"HTTP/1.1 202 Giraffe\n" \
b"X-TEST: PASS\n" \
b"Content-Length: 7\n\n" \
def handler(request, response):
route = ("GET", "/test/test_write_raw_content", handler)
resp = self.request(route[1])
assert resp.getcode() == 202
assert resp.info()["X-TEST"] == "PASS"
assert resp.read() == b"Content"
def test_write_raw_content_file(self):
def handler(request, response):
with open(os.path.join(doc_root, "test.asis"), 'rb') as infile:
route = ("GET", "/test/test_write_raw_content", handler)
resp = self.request(route[1])
assert resp.getcode() == 202
assert resp.info()["X-TEST"] == "PASS"
assert resp.read() == b"Content"
def test_write_raw_none(self):
def handler(request, response):
with pytest.raises(ValueError):
route = ("GET", "/test/test_write_raw_content", handler)
def test_write_raw_contents_invalid_http(self):
resp_content = b"INVALID HTTP"
def handler(request, response):
route = ("GET", "/test/test_write_raw_content", handler)
resp = self.request(route[1])
assert resp.read() == resp_content
except BadStatusLine as e:
# In Python 3, an invalid HTTP response should raise BadStatusLine.
assert PY3
assert str(e) == resp_content.decode('utf-8')
class TestH2Response(TestUsingH2Server):
def test_write_without_ending_stream(self):
data = b"TEST"
def handler(request, response):
headers = [
('server', 'test-h2'),
('test', 'PASS'),
response.writer.write_headers(headers, 202)
response.writer.write_data_frame(data, False)
# Should detect stream isn't ended and call `writer.end_stream()`
route = ("GET", "/h2test/test", handler)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 202
assert [x for x in resp.headers.items()] == [(b'server', b'test-h2'), (b'test', b'PASS')]
assert resp.read() == data
def test_push(self):
data = b"TEST"
push_data = b"PUSH TEST"
def handler(request, response):
headers = [
('server', 'test-h2'),
('test', 'PASS'),
response.writer.write_headers(headers, 202)
promise_headers = [
(':method', 'GET'),
(':path', '/push-test'),
(':scheme', 'https'),
(':authority', '%s:%i' % (self.server.host, self.server.port))
push_headers = [
('server', 'test-h2'),
('content-length', str(len(push_data))),
('content-type', 'text'),
response.writer.write_data_frame(data, True)
route = ("GET", "/h2test/test_push", handler)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 202
assert [x for x in resp.headers.items()] == [(b'server', b'test-h2'), (b'test', b'PASS')]
assert resp.read() == data
push_promise = next(self.conn.get_pushes())
push = push_promise.get_response()
assert push_promise.path == b'/push-test'
assert push.status == 203
assert push.read() == push_data
def test_set_error(self):
def handler(request, response):
response.set_error(503, message="Test error")
route = ("GET", "/h2test/test_set_error", handler)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 503
assert json.loads(resp.read()) == json.loads("{\"error\": {\"message\": \"Test error\", \"code\": 503}}")
def test_file_like_response(self):
def handler(request, response):
content = BytesIO(b"Hello, world!")
response.content = content
route = ("GET", "/h2test/test_file_like_response", handler)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 200
assert resp.read() == b"Hello, world!"
def test_list_response(self):
def handler(request, response):
response.content = ['hello', 'world']
route = ("GET", "/h2test/test_file_like_response", handler)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 200
assert resp.read() == b"helloworld"
def test_content_longer_than_frame_size(self):
def handler(request, response):
size = response.writer.get_max_payload_size()
content = "a" * (size + 5)
return [('payload_size', size)], content
route = ("GET", "/h2test/test_content_longer_than_frame_size", handler)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 200
payload_size = int(resp.headers['payload_size'][0])
assert payload_size
assert resp.read() == b"a" * (payload_size + 5)
def test_encode(self):
def handler(request, response):
response.encoding = "utf8"
t = response.writer.encode(u"hello")
assert t == "hello"
with pytest.raises(ValueError):
route = ("GET", "/h2test/test_content_longer_than_frame_size", handler)
self.conn.request(route[0], route[1])
def test_raw_header_frame(self):
def handler(request, response):
(':status', '204'),
('server', 'TEST-H2')
], end_headers=True)
route = ("GET", "/h2test/test_file_like_response", handler)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 204
assert resp.headers['server'][0] == b'TEST-H2'
assert resp.read() == b''
def test_raw_header_frame_invalid(self):
def handler(request, response):
('server', 'TEST-H2'),
(':status', '204')
], end_headers=True)
route = ("GET", "/h2test/test_file_like_response", handler)
self.conn.request(route[0], route[1])
with pytest.raises(ProtocolError):
# The server can send an invalid HEADERS frame, which will cause a protocol error in the client
def test_raw_data_frame(self):
def handler(request, response):
response.writer.write_raw_data_frame(data=b'Hello world', end_stream=True)
route = ("GET", "/h2test/test_file_like_response", handler)
sid = self.conn.request(route[0], route[1])
assert self.conn.streams[sid]._read() == b'Hello world'
def test_raw_header_continuation_frame(self):
def handler(request, response):
(':status', '204')
('server', 'TEST-H2')
], end_headers=True)
route = ("GET", "/h2test/test_file_like_response", handler)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 204
assert resp.headers['server'][0] == b'TEST-H2'
assert resp.read() == b''
if __name__ == '__main__':
import unittest
import uuid
import pytest
wptserve = pytest.importorskip("wptserve")
from wptserve.router import any_method
from wptserve.stash import StashServer
from .base import TestUsingServer
class TestResponseSetCookie(TestUsingServer):
def run(self, result=None):
with StashServer(None, authkey=str(uuid.uuid4())):
super(TestResponseSetCookie, self).run(result)
def test_put_take(self):
def handler(request, response):
if request.method == "POST":
request.server.stash.put(request.POST.first(b"id"), request.POST.first(b"data"))
data = "OK"
elif request.method == "GET":
data = request.server.stash.take(request.GET.first(b"id"))
if data is None:
return "NOT FOUND"
return data
id = str(uuid.uuid4())
route = (any_method, "/test/put_take", handler)
resp = self.request(route[1], method="POST", body={"id": id, "data": "Sample data"})
self.assertEqual(resp.read(), b"OK")
resp = self.request(route[1], query="id=" + id)
self.assertEqual(resp.read(), b"Sample data")
resp = self.request(route[1], query="id=" + id)
self.assertEqual(resp.read(), b"NOT FOUND")
if __name__ == '__main__':
from __future__ import unicode_literals
import pytest
from wptserve.pipes import ReplacementTokenizer
[b"aaa", [('ident', 'aaa')]],
[b"bbb()", [('ident', 'bbb'), ('arguments', [])]],
[b"bcd(uvw, xyz)", [('ident', 'bcd'), ('arguments', ['uvw', 'xyz'])]],
[b"$ccc:ddd", [('var', '$ccc'), ('ident', 'ddd')]],
[b"$eee", [('ident', '$eee')]],
[b"fff[0]", [('ident', 'fff'), ('index', 0)]],
[b"ggg[hhh]", [('ident', 'ggg'), ('index', 'hhh')]],
[b"[iii]", [('index', 'iii')]],
[b"jjj['kkk']", [('ident', 'jjj'), ('index', "'kkk'")]],
[b"lll[]", [('ident', 'lll'), ('index', "")]],
[b"111", [('ident', '111')]],
[b"$111", [('ident', '$111')]],
def test_tokenizer(content, expected):
tokenizer = ReplacementTokenizer()
tokens = tokenizer.tokenize(content)
assert expected == tokens
[b"/", []],
[b"$aaa: BBB", [('var', '$aaa')]],
def test_tokenizer_errors(content, expected):
tokenizer = ReplacementTokenizer()
tokens = tokenizer.tokenize(content)
assert expected == tokens
import mock
from six import BytesIO
from wptserve.response import Response
def test_response_status():
cases = [200, (200, b'OK'), (200, u'OK'), ('200', 'OK')]
for case in cases:
handler = mock.Mock()
handler.wfile = BytesIO()
request = mock.Mock()
request.protocol_version = 'HTTP/1.1'
response = Response(handler, request)
response.status = case
expected = case if isinstance(case, tuple) else (case, None)
if expected[0] == '200':
expected = (200, expected[1])
assert response.status == expected
assert handler.wfile.getvalue() == b'HTTP/1.1 200 OK\r\n'
def test_response_status_not_string():
# This behaviour is not documented but kept for backward compatibility.
handler = mock.Mock()
request = mock.Mock()
response = Response(handler, request)
response.status = (200, 100)
assert response.status == (200, '100')
from .server import WebTestHttpd, WebTestServer, Router # noqa: F401
from .request import Request # noqa: F401
from .response import Response # noqa: F401
@ -0,0 +1,353 @@
import copy
import logging
import os
from collections import defaultdict
from six.moves.collections_abc import Mapping
from six import integer_types, iteritems, itervalues, string_types
from . import sslutils
from .utils import get_port
_renamed_props = {
"host": "browser_host",
"bind_hostname": "bind_address",
"external_host": "server_host",
"host_ip": "server_host",
def _merge_dict(base_dict, override_dict):
rv = base_dict.copy()
for key, value in iteritems(base_dict):
if key in override_dict:
if isinstance(value, dict):
rv[key] = _merge_dict(value, override_dict[key])
rv[key] = override_dict[key]
return rv
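A minimal standalone sketch of the recursive merge above, assuming the nested-dict branch and the plain-value branch are alternatives (the `else` appears to have been lost in this rendering). The name `merge_dict` and the sample data are illustrative only. Note that only keys already present in the base dict are considered, so unknown override keys are dropped.

```python
def merge_dict(base, override):
    """Return a copy of base with values from override applied recursively."""
    merged = base.copy()
    for key, value in base.items():
        if key in override:
            if isinstance(value, dict):
                # Nested dicts are merged key by key rather than replaced
                merged[key] = merge_dict(value, override[key])
            else:
                merged[key] = override[key]
    return merged


base = {"ports": {"http": [8000], "https": [8443]}, "log_level": "debug"}
override = {"ports": {"http": [9000]}, "log_level": "info", "unknown": 1}
merged = merge_dict(base, override)
```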
class Config(Mapping):
"""wptserve config
Inherits from Mapping for backwards compatibility with the old dict-based config"""
def __init__(self, logger_name, data):
self.__dict__["_logger_name"] = logger_name
def __str__(self):
return str(self.__dict__)
def __setattr__(self, key, value):
raise ValueError("Config is immutable")
def __setitem__(self, key, value):
raise ValueError("Config is immutable")
def __getitem__(self, key):
return getattr(self, key)
except AttributeError:
raise ValueError
def __contains__(self, key):
return key in self.__dict__
def __iter__(self):
return (x for x in self.__dict__ if not x.startswith("_"))
def __len__(self):
return len([item for item in self])
def logger(self):
logger = logging.getLogger(self._logger_name)
return logger
def as_dict(self):
return json_types(self.__dict__)
# Environment variables are limited in size so we need to prune the most egregious contributors
# to size, the origin policy subdomains.
def as_dict_for_wd_env_variable(self):
result = self.as_dict()
for key in [
("domains", "alt"),
("domains", ""),
("all_domains", "alt"),
("all_domains", ""),
target = result
for part in key[:-1]:
target = target[part]
value = target[key[-1]]
if isinstance(value, dict):
target[key[-1]] = {k:v for (k,v) in iteritems(value) if not k.startswith("op")}
target[key[-1]] = [x for x in value if not x.startswith("op")]
return result
def json_types(obj):
if isinstance(obj, dict):
return {key: json_types(value) for key, value in iteritems(obj)}
if (isinstance(obj, string_types) or
isinstance(obj, integer_types) or
isinstance(obj, float) or
isinstance(obj, bool) or
obj is None):
return obj
if isinstance(obj, list) or hasattr(obj, "__iter__"):
return [json_types(value) for value in obj]
raise ValueError
class ConfigBuilder(object):
"""Builder object for setting the wptserve config.
Configuration can be passed in as a dictionary to the constructor, or
set via attributes after construction. Configuration options must match
the keys on the _default class property.
The generated configuration is obtained by using the builder
object as a context manager; this returns a Config object
containing immutable configuration that may be shared between
threads and processes. In general the configuration is only valid
for the context used to obtain it.
with ConfigBuilder() as config:
# Use the configuration
print(config.browser_host)
The properties on the final configuration include those explicitly
supplied and computed properties. The computed properties are
defined by the computed_properties attribute on the class. This
is a list of property names, each corresponding to a _get_<name>
method on the class. These methods are called in the order defined
in computed_properties and are passed a single argument, a
dictionary containing the current set of properties. Thus computed
properties later in the list may depend on the value of earlier
_default = {
"browser_host": "localhost",
"alternate_hosts": {},
"doc_root": os.path.dirname("__file__"),
"server_host": None,
"ports": {"http": [8000]},
"check_subdomains": True,
"log_level": "debug",
"bind_address": True,
"ssl": {
"type": "none",
"encrypt_after_connect": False,
"none": {},
"openssl": {
"openssl_binary": "openssl",
"base_path": "_certs",
"password": "web-platform-tests",
"force_regenerate": False,
"duration": 30,
"base_conf_path": None
"pregenerated": {
"host_key_path": None,
"host_cert_path": None,
"aliases": []
default_config_cls = Config
# Configuration properties that are computed. Each corresponds to a method
# _get_foo, which is called with the current data dictionary. The properties
# are computed in the order specified in the list.
computed_properties = ["log_level",
def __init__(self,
self._data = self._default.copy()
self._ssl_env = None
self._config_cls = config_cls or self.default_config_cls
if logger is None:
self._logger_name = "web-platform-tests"
level_name = logging.getLevelName(logger.level)
if level_name != "NOTSET":
self.log_level = level_name
self._logger_name = logger.name
for k, v in iteritems(self._default):
self._data[k] = kwargs.pop(k, v)
self._data["subdomains"] = subdomains
self._data["not_subdomains"] = not_subdomains
for k, new_k in iteritems(_renamed_props):
if k in kwargs:
"%s in config is deprecated; use %s instead" % (
self._data[new_k] = kwargs.pop(k)
if kwargs:
raise TypeError("__init__() got unexpected keyword arguments %r" % (tuple(kwargs),))
def __setattr__(self, key, value):
if not key[0] == "_":
self._data[key] = value
self.__dict__[key] = value
def logger(self):
logger = logging.getLogger(self._logger_name)
return logger
def update(self, override):
"""Load an overrides dict to override config values"""
override = override.copy()
for k in self._default:
if k in override:
self._set_override(k, override.pop(k))
for k, new_k in iteritems(_renamed_props):
if k in override:
"%s in config is deprecated; use %s instead" % (
self._set_override(new_k, override.pop(k))
if override:
k = next(iter(override))
raise KeyError("unknown config override '%s'" % k)
def _set_override(self, k, v):
old_v = self._data[k]
if isinstance(old_v, dict):
self._data[k] = _merge_dict(old_v, v)
self._data[k] = v
def __enter__(self):
if self._ssl_env is not None:
raise ValueError("Tried to re-enter configuration")
data = self._data.copy()
prefix = "_get_"
for key in self.computed_properties:
data[key] = getattr(self, prefix + key)(data)
return self._config_cls(self._logger_name, data)
def __exit__(self, *args):
self._ssl_env = None
def _get_log_level(self, data):
return data["log_level"].upper()
def _get_paths(self, data):
return {"doc_root": data["doc_root"]}
def _get_server_host(self, data):
return data["server_host"] if data.get("server_host") is not None else data["browser_host"]
def _get_ports(self, data):
new_ports = defaultdict(list)
for scheme, ports in iteritems(data["ports"]):
if scheme in ["wss", "https"] and not sslutils.get_cls(data["ssl"]["type"]).ssl_enabled:
for i, port in enumerate(ports):
real_port = get_port("") if port == "auto" else port
return new_ports
def _get_domains(self, data):
hosts = data["alternate_hosts"].copy()
assert "" not in hosts
hosts[""] = data["browser_host"]
rv = {}
for name, host in iteritems(hosts):
rv[name] = {subdomain: (subdomain.encode("idna").decode("ascii") + u"." + host)
for subdomain in data["subdomains"]}
rv[name][""] = host
return rv
def _get_not_domains(self, data):
hosts = data["alternate_hosts"].copy()
assert "" not in hosts
hosts[""] = data["browser_host"]
rv = {}
for name, host in iteritems(hosts):
rv[name] = {subdomain: (subdomain.encode("idna").decode("ascii") + u"." + host)
for subdomain in data["not_subdomains"]}
return rv
def _get_all_domains(self, data):
rv = copy.deepcopy(data["domains"])
nd = data["not_domains"]
for host in rv:
return rv
def _get_domains_set(self, data):
return {domain
for per_host_domains in itervalues(data["domains"])
for domain in itervalues(per_host_domains)}
def _get_not_domains_set(self, data):
return {domain
for per_host_domains in itervalues(data["not_domains"])
for domain in itervalues(per_host_domains)}
def _get_all_domains_set(self, data):
return data["domains_set"] | data["not_domains_set"]
def _get_ssl_config(self, data):
ssl_type = data["ssl"]["type"]
ssl_cls = sslutils.get_cls(ssl_type)
kwargs = data["ssl"].get(ssl_type, {})
self._ssl_env = ssl_cls(self.logger, **kwargs)
if self._ssl_env.ssl_enabled:
key_path, cert_path = self._ssl_env.host_cert_path(data["domains_set"])
ca_cert_path = self._ssl_env.ca_cert_path(data["domains_set"])
return {"key_path": key_path,
"ca_cert_path": ca_cert_path,
"cert_path": cert_path,
"encrypt_after_connect": data["ssl"].get("encrypt_after_connect", False)}
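The computed-properties mechanism described in the ConfigBuilder docstring can be illustrated with a much smaller builder. This is a sketch under stated assumptions, not the vendored implementation: `MiniBuilder`, its defaults, and the `origin` property are hypothetical, and it returns a plain dict rather than an immutable Config object.

```python
class MiniBuilder(object):
    # Evaluated in order, so "origin" may rely on the computed "server_host"
    computed_properties = ["server_host", "origin"]

    def __init__(self, **data):
        self._data = {"browser_host": "localhost", "server_host": None, "port": 8000}
        self._data.update(data)

    def __enter__(self):
        data = self._data.copy()
        for key in self.computed_properties:
            # Each property name maps to a _get_<name> method taking the data so far
            data[key] = getattr(self, "_get_" + key)(data)
        return data

    def __exit__(self, *args):
        pass

    def _get_server_host(self, data):
        if data["server_host"] is not None:
            return data["server_host"]
        return data["browser_host"]

    def _get_origin(self, data):
        return "http://%s:%d" % (data["server_host"], data["port"])


with MiniBuilder(port=8080) as config:
    origin = config["origin"]
```

Because the getters run in list order, reordering `computed_properties` so that `origin` precedes `server_host` would make `origin` see the unresolved `None` value.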
@ -0,0 +1,97 @@
from . import utils
content_types = utils.invert_dict({
"application/json": ["json"],
"application/wasm": ["wasm"],
"application/xhtml+xml": ["xht", "xhtm", "xhtml"],
"application/xml": ["xml"],
"application/x-xpinstall": ["xpi"],
"audio/mp4": ["m4a"],
"audio/mpeg": ["mp3"],
"audio/ogg": ["oga"],
"audio/webm": ["weba"],
"audio/x-wav": ["wav"],
"image/bmp": ["bmp"],
"image/gif": ["gif"],
"image/jpeg": ["jpg", "jpeg"],
"image/png": ["png"],
"image/svg+xml": ["svg"],
"text/cache-manifest": ["manifest"],
"text/css": ["css"],
"text/event-stream": ["event_stream"],
"text/html": ["htm", "html"],
"text/javascript": ["js", "mjs"],
"text/plain": ["txt", "md"],
"text/vtt": ["vtt"],
"video/mp4": ["mp4", "m4v"],
"video/ogg": ["ogg", "ogv"],
"video/webm": ["webm"],
response_codes = {
100: ('Continue', 'Request received, please continue'),
101: ('Switching Protocols',
'Switching to new protocol; obey Upgrade header'),
200: ('OK', 'Request fulfilled, document follows'),
201: ('Created', 'Document created, URL follows'),
202: ('Accepted',
'Request accepted, processing continues off-line'),
203: ('Non-Authoritative Information', 'Request fulfilled from cache'),
204: ('No Content', 'Request fulfilled, nothing follows'),
205: ('Reset Content', 'Clear input form for further input.'),
206: ('Partial Content', 'Partial content follows.'),
300: ('Multiple Choices',
'Object has several resources -- see URI list'),
301: ('Moved Permanently', 'Object moved permanently -- see URI list'),
302: ('Found', 'Object moved temporarily -- see URI list'),
303: ('See Other', 'Object moved -- see Method and URL list'),
304: ('Not Modified',
'Document has not changed since given time'),
305: ('Use Proxy',
'You must use proxy specified in Location to access this '
307: ('Temporary Redirect',
'Object moved temporarily -- see URI list'),
400: ('Bad Request',
'Bad request syntax or unsupported method'),
401: ('Unauthorized',
'No permission -- see authorization schemes'),
402: ('Payment Required',
'No payment -- see charging schemes'),
403: ('Forbidden',
'Request forbidden -- authorization will not help'),
404: ('Not Found', 'Nothing matches the given URI'),
405: ('Method Not Allowed',
'Specified method is invalid for this resource.'),
406: ('Not Acceptable', 'URI not available in preferred format.'),
407: ('Proxy Authentication Required', 'You must authenticate with '
'this proxy before proceeding.'),
408: ('Request Timeout', 'Request timed out; try again later.'),
409: ('Conflict', 'Request conflict.'),
410: ('Gone',
'URI no longer exists and has been permanently removed.'),
411: ('Length Required', 'Client must specify Content-Length.'),
412: ('Precondition Failed', 'Precondition in headers is false.'),
413: ('Request Entity Too Large', 'Entity is too large.'),
414: ('Request-URI Too Long', 'URI is too long.'),
415: ('Unsupported Media Type', 'Entity body in unsupported format.'),
416: ('Requested Range Not Satisfiable',
'Cannot satisfy request range.'),
417: ('Expectation Failed',
'Expect condition could not be satisfied.'),
500: ('Internal Server Error', 'Server got itself in trouble'),
501: ('Not Implemented',
'Server does not support this operation'),
502: ('Bad Gateway', 'Invalid responses from another server/proxy.'),
503: ('Service Unavailable',
'The server cannot process the request due to a high load'),
504: ('Gateway Timeout',
'The gateway server did not receive a timely response'),
505: ('HTTP Version Not Supported', 'Cannot fulfill request.'),
h2_headers = ['method', 'scheme', 'host', 'path', 'authority', 'status']
@ -0,0 +1,514 @@
import json
import os
import traceback
from collections import defaultdict
from six.moves.urllib.parse import quote, unquote, urljoin
from six import iteritems
from .constants import content_types
from .pipes import Pipeline, template
from .ranges import RangeParser
from .request import Authentication
from .response import MultipartContent
from .utils import HTTPException
from html import escape
except ImportError:
from cgi import escape
__all__ = ["file_handler", "python_script_handler",
"FunctionHandler", "handler", "json_handler",
"as_is_handler", "ErrorHandler", "BasicAuthHandler"]
def guess_content_type(path):
ext = os.path.splitext(path)[1].lstrip(".")
if ext in content_types:
return content_types[ext]
return "application/octet-stream"
def filesystem_path(base_path, request, url_base="/"):
if base_path is None:
base_path = request.doc_root
path = unquote(request.url_parts.path)
if path.startswith(url_base):
path = path[len(url_base):]
if ".." in path:
raise HTTPException(404)
new_path = os.path.join(base_path, path)
# Otherwise setting path to / allows access outside the root directory
if not new_path.startswith(base_path):
raise HTTPException(404)
return new_path
class DirectoryHandler(object):
def __init__(self, base_path=None, url_base="/"):
self.base_path = base_path
self.url_base = url_base
def __repr__(self):
return "<%s base_path:%s url_base:%s>" % (self.__class__.__name__, self.base_path, self.url_base)
def __call__(self, request, response):
url_path = request.url_parts.path
if not url_path.endswith("/"):
response.status = 301
response.headers = [("Location", "%s/" % request.url)]
path = filesystem_path(self.base_path, request, self.url_base)
assert os.path.isdir(path)
response.headers = [("Content-Type", "text/html")]
response.content = """<!doctype html>
<meta name="viewport" content="width=device-width">
<title>Directory listing for %(path)s</title>
<h1>Directory listing for %(path)s</h1>
""" % {"path": escape(url_path),
"items": "\n".join(self.list_items(url_path, path))} # noqa: E122
def list_items(self, base_path, path):
assert base_path.endswith("/")
# TODO: this won't actually list all routes, only the
# ones that correspond to a real filesystem path. It's
# not possible to list every route that will match
# something, but it should be possible to at least list the
# statically defined ones
if base_path != "/":
link = urljoin(base_path, "..")
yield ("""<li class="dir"><a href="%(link)s">%(name)s</a></li>""" %
{"link": link, "name": ".."})
items = []
prev_item = None
# This ensures that .headers always sorts after the file it provides the headers for. E.g.,
# if we have x, x-y, and x.headers, the order will be x, x.headers, and then x-y.
for item in sorted(os.listdir(path), key=lambda x: (x[:-len(".headers")], x) if x.endswith(".headers") else (x, x)):
if prev_item and prev_item + ".headers" == item:
items[-1][1] = item
prev_item = None
items.append([item, None])
prev_item = item
for item, dot_headers in items:
link = escape(quote(item))
dot_headers_markup = ""
if dot_headers is not None:
dot_headers_markup = (""" (<a href="%(link)s">.headers</a>)""" %
{"link": escape(quote(dot_headers))})
if os.path.isdir(os.path.join(path, item)):
link += "/"
class_ = "dir"
class_ = "file"
yield ("""<li class="%(class)s"><a href="%(link)s">%(name)s</a>%(headers)s</li>""" %
{"link": link, "name": escape(item), "class": class_,
"headers": dot_headers_markup})
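The sort key used in the directory listing can be tried in isolation. The sketch below pulls the lambda out into a named function (`headers_sort_key` is an illustrative name) and uses the file names from the comment above to show the difference from a plain sort.

```python
def headers_sort_key(name):
    """Sort "<file>.headers" directly after "<file>"."""
    suffix = ".headers"
    if name.endswith(suffix):
        # Compare on the base name first, then on the full name
        return (name[:-len(suffix)], name)
    return (name, name)


names = ["x-y", "x.headers", "x"]
ordered = sorted(names, key=headers_sort_key)
plain = sorted(names)
```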
def parse_qs(qs):
"""Parse a query string given as a string argument (data of type
application/x-www-form-urlencoded). Data are returned as a dictionary. The
dictionary keys are the unique query variable names and the values are
lists of values for each name.
This implementation is used instead of Python's built-in `parse_qs` method
in order to support the semicolon character (which the built-in method
interprets as a parameter delimiter)."""
pairs = [item.split("=", 1) for item in qs.split('&') if item]
rv = defaultdict(list)
for pair in pairs:
if len(pair) == 1 or len(pair[1]) == 0:
name = unquote(pair[0].replace('+', ' '))
value = unquote(pair[1].replace('+', ' '))
return dict(rv)
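A minimal sketch of the semicolon-preserving parsing that the docstring describes, written for Python 3 only. The function name `parse_qs_keep_semicolons` is hypothetical, and skipping pairs with a missing or empty value is an assumption based on the length check above.

```python
from collections import defaultdict
from urllib.parse import unquote


def parse_qs_keep_semicolons(qs):
    """Split on '&' only, so ';' stays part of the value."""
    result = defaultdict(list)
    for item in qs.split("&"):
        if not item:
            continue
        name, sep, value = item.partition("=")
        # Assumption: names without a value are ignored
        if not sep or not value:
            continue
        name = unquote(name.replace("+", " "))
        value = unquote(value.replace("+", " "))
        result[name].append(value)
    return dict(result)


parsed = parse_qs_keep_semicolons("pipe=header(X-A,b;c)&id=1&id=2&empty=")
```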
def wrap_pipeline(path, request, response):
"""Applies pipelines to a response.
Pipelines are specified in the filename (.sub.) or the query param (?pipe).
query = parse_qs(request.url_parts.query)
pipe_string = ""
if ".sub." in path:
ml_extensions = {".html", ".htm", ".xht", ".xhtml", ".xml", ".svg"}
escape_type = "html" if os.path.splitext(path)[1] in ml_extensions else "none"
pipe_string = "sub(%s)" % escape_type
if "pipe" in query:
if pipe_string:
pipe_string += "|"
pipe_string += query["pipe"][-1]
if pipe_string:
response = Pipeline(pipe_string)(request, response)
return response
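How the pipe string is assembled from the filename and the query can be sketched without a server. `build_pipe_string` is an illustrative helper, not part of wptserve; it takes an already-parsed query dict and returns the string that would be handed to the pipeline.

```python
import os

ML_EXTENSIONS = {".html", ".htm", ".xht", ".xhtml", ".xml", ".svg"}


def build_pipe_string(path, query):
    """Combine the implicit .sub. pipe with an explicit ?pipe= value."""
    pipe_string = ""
    if ".sub." in path:
        # Markup files get HTML escaping; everything else is substituted raw
        escape_type = "html" if os.path.splitext(path)[1] in ML_EXTENSIONS else "none"
        pipe_string = "sub(%s)" % escape_type
    if "pipe" in query:
        if pipe_string:
            pipe_string += "|"
        # Only the last pipe parameter is used
        pipe_string += query["pipe"][-1]
    return pipe_string


combined = build_pipe_string("/sub_url_base.sub.txt", {"pipe": ["slice(5,10)"]})
```

The filename-derived `sub` pipe always runs first, so a `slice` given in the query operates on the substituted output.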
def load_headers(request, path):
"""Loads headers from files for a given path.
Attempts to load both the neighbouring __dir__{.sub}.headers and
PATH{.sub}.headers (applying template substitution if needed); results are
concatenated in that order.
def _load(request, path):
headers_path = path + ".sub.headers"
if os.path.exists(headers_path):
use_sub = True
headers_path = path + ".headers"
use_sub = False
with open(headers_path, "rb") as headers_file:
data = headers_file.read()
except IOError:
return []
if use_sub:
data = template(request, data, escape_type="none")
return [tuple(item.strip() for item in line.split(b":", 1))
for line in data.splitlines() if line]
return (_load(request, os.path.join(os.path.dirname(path), "__dir__")) +
_load(request, path))
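The line-splitting step inside `_load` can be exercised on its own. This sketch copies that expression into a hypothetical helper, `parse_header_lines`, and runs it on made-up header data. Only the first colon separates name from value, and blank lines are skipped.

```python
def parse_header_lines(data):
    """Turn the raw bytes of a .headers file into (name, value) tuples."""
    return [tuple(item.strip() for item in line.split(b":", 1))
            for line in data.splitlines() if line]


raw = b"Content-Type: text/html\nX-Test: a:b\n\n"
parsed_headers = parse_header_lines(raw)
```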
class FileHandler(object):
def __init__(self, base_path=None, url_base="/"):
self.base_path = base_path
self.url_base = url_base
self.directory_handler = DirectoryHandler(self.base_path, self.url_base)
def __repr__(self):
return "<%s base_path:%s url_base:%s>" % (self.__class__.__name__, self.base_path, self.url_base)
def __call__(self, request, response):
path = filesystem_path(self.base_path, request, self.url_base)
if os.path.isdir(path):
return self.directory_handler(request, response)
#This is probably racy with some other process trying to change the file
file_size = os.stat(path).st_size
response.headers.update(self.get_headers(request, path))
if "Range" in request.headers:
byte_ranges = RangeParser()(request.headers['Range'], file_size)
except HTTPException as e:
if e.code == 416:
response.headers.set("Content-Range", "bytes */%i" % file_size)
byte_ranges = None
data = self.get_data(response, path, byte_ranges)
response.content = data
response = wrap_pipeline(path, request, response)
return response
except (OSError, IOError):
raise HTTPException(404)
def get_headers(self, request, path):
rv = load_headers(request, path)
if not any(key.lower() == b"content-type" for (key, _) in rv):
rv.insert(0, (b"Content-Type", guess_content_type(path).encode("ascii")))
return rv
def get_data(self, response, path, byte_ranges):
"""Return either the handle to a file, or a string containing
the content of a chunk of the file, if we have a range request."""
if byte_ranges is None:
return open(path, 'rb')
with open(path, 'rb') as f:
response.status = 206
if len(byte_ranges) > 1:
parts_content_type, content = self.set_response_multipart(response,
for byte_range in byte_ranges:
content.append_part(self.get_range_data(f, byte_range),
[("Content-Range", byte_range.header_value())])
return content
response.headers.set("Content-Range", byte_ranges[0].header_value())
return self.get_range_data(f, byte_ranges[0])
def set_response_multipart(self, response, ranges, f):
parts_content_type = response.headers.get("Content-Type")
if parts_content_type:
parts_content_type = parts_content_type[-1]
parts_content_type = None
content = MultipartContent()
response.headers.set("Content-Type", "multipart/byteranges; boundary=%s" % content.boundary)
return parts_content_type, content
def get_range_data(self, f, byte_range):
return f.read(byte_range.upper - byte_range.lower)
file_handler = FileHandler()
class PythonScriptHandler(object):
def __init__(self, base_path=None, url_base="/"):
self.base_path = base_path
self.url_base = url_base
def __repr__(self):
return "<%s base_path:%s url_base:%s>" % (self.__class__.__name__, self.base_path, self.url_base)
def _load_file(self, request, response, func):
This loads the requested python file as an environ variable.
Once the environ is loaded, the passed `func` is run with this loaded environ.
:param request: The request object
:param response: The response object
:param func: The function to be run with the loaded environ with the modified filepath. Signature: (request, response, environ, path)
:return: The return of func
path = filesystem_path(self.base_path, request, self.url_base)
environ = {"__file__": path}
with open(path, 'rb') as f:
exec(compile(f.read(), path, 'exec'), environ, environ)
if func is not None:
return func(request, response, environ, path)
except IOError:
raise HTTPException(404)
def __call__(self, request, response):
def func(request, response, environ, path):
if "main" in environ:
handler = FunctionHandler(environ["main"])
handler(request, response)
wrap_pipeline(path, request, response)
raise HTTPException(500, "No main function in script %s" % path)
self._load_file(request, response, func)
def frame_handler(self, request):
This creates a FunctionHandler with one or more of the handling functions.
Used by the H2 server.
:param request: The request object used to generate the handler.
:return: A FunctionHandler object with one or more of these functions: `handle_headers`, `handle_data` or `main`
def func(request, response, environ, path):
def _main(req, resp):
handler = FunctionHandler(_main)
if "main" in environ:
handler.func = environ["main"]
if "handle_headers" in environ:
handler.handle_headers = environ["handle_headers"]
if "handle_data" in environ:
handler.handle_data = environ["handle_data"]
if handler.func is _main and not hasattr(handler, "handle_headers") and not hasattr(handler, "handle_data"):
raise HTTPException(500, "No main function or handlers in script %s" % path)
return handler
return self._load_file(request, None, func)
python_script_handler = PythonScriptHandler()
class FunctionHandler(object):
    def __init__(self, func):
        self.func = func

    def __call__(self, request, response):
        try:
            rv = self.func(request, response)
        except HTTPException:
            # HTTP errors raised by the handler propagate unchanged.
            raise
        except Exception:
            # Any other failure becomes a 500 carrying the traceback.
            msg = traceback.format_exc()
            raise HTTPException(500, message=msg)
        if rv is not None:
            if isinstance(rv, tuple):
                if len(rv) == 3:
                    status, headers, content = rv
                    response.status = status
                elif len(rv) == 2:
                    headers, content = rv
                else:
                    raise HTTPException(500)
                response.headers.update(headers)
            else:
                content = rv
            response.content = content
            wrap_pipeline('', request, response)
# The generic name here is so that this can be used as a decorator
def handler(func):
return FunctionHandler(func)
class JsonHandler(object):
def __init__(self, func):
self.func = func
def __call__(self, request, response):
return FunctionHandler(self.handle_request)(request, response)
def handle_request(self, request, response):
rv = self.func(request, response)
response.headers.set("Content-Type", "application/json")
enc = json.dumps
if isinstance(rv, tuple):
rv = list(rv)
value = tuple(rv[:-1] + [enc(rv[-1])])
length = len(value[-1])
value = enc(rv)
length = len(value)
response.headers.set("Content-Length", length)
return value
def json_handler(func):
return JsonHandler(func)
class AsIsHandler(object):
def __init__(self, base_path=None, url_base="/"):
self.base_path = base_path
self.url_base = url_base
def __call__(self, request, response):
path = filesystem_path(self.base_path, request, self.url_base)
with open(path, 'rb') as f:
wrap_pipeline(path, request, response)
response.close_connection = True
except IOError:
raise HTTPException(404)
as_is_handler = AsIsHandler()
class BasicAuthHandler(object):
def __init__(self, handler, user, password):
A Basic Auth handler
- handler: a secondary handler for the request after authentication is successful (example file_handler)
- user: string of the valid user name or None if any / all credentials are allowed
- password: string of the password required
self.user = user
self.password = password
self.handler = handler
def __call__(self, request, response):
if "authorization" not in request.headers:
response.status = 401
response.headers.set("WWW-Authenticate", "Basic")
return response
auth = Authentication(request.headers)
if self.user is not None and (self.user != auth.username or self.password != auth.password):
response.set_error(403, "Invalid username or password")
return response
return self.handler(request, response)
basic_auth_handler = BasicAuthHandler(file_handler, None, None)
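The `Authentication` class used above is not part of this excerpt. As a hedged illustration of what a Basic `Authorization` header carries, the following sketch decodes one by hand; `parse_basic_auth` is a hypothetical helper, not wptserve's implementation.

```python
import base64


def parse_basic_auth(header_value):
    """Decode 'Basic <base64(user:password)>' into (user, password)."""
    scheme, _, encoded = header_value.partition(" ")
    if scheme.lower() != "basic" or not encoded:
        raise ValueError("Not a Basic Authorization header")
    decoded = base64.b64decode(encoded).decode("utf-8")
    # Only the first ':' separates the user name from the password.
    username, _, password = decoded.partition(":")
    return username, password


header_value = "Basic " + base64.b64encode(b"user:pa:ss").decode("ascii")
credentials = parse_basic_auth(header_value)
```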
class ErrorHandler(object):
    def __init__(self, status):
        self.status = status

    def __call__(self, request, response):
        response.set_error(self.status)
class StringHandler(object):
def __init__(self, data, content_type, **headers):
"""Handler that returns a fixed data string and headers
:param data: String to use
:param content_type: Content type header to serve the response with
:param headers: List of headers to send with responses"""
self.data = data
self.resp_headers = [("Content-Type", content_type)]
for k, v in iteritems(headers):
self.resp_headers.append((k.replace("_", "-"), v))
self.handler = handler(self.handle_request)
def handle_request(self, request, response):
return self.resp_headers, self.data
def __call__(self, request, response):
rv = self.handler(request, response)
return rv
class StaticHandler(StringHandler):
def __init__(self, path, format_args, content_type, **headers):
"""Handler that reads a file from a path and substitutes some fixed data
Note that *.headers files have no effect in this handler.
:param path: Path to the template file to use
:param format_args: Dictionary of values to substitute into the template file
:param content_type: Content type header to serve the response with
:param headers: List of headers to send with responses"""
with open(path) as f:
data = f.read()
if format_args:
data = data % format_args
return super(StaticHandler, self).__init__(data, content_type, **headers)
@ -0,0 +1,29 @@
class NoOpLogger(object):
    def critical(self, msg):
        pass

    def error(self, msg):
        pass

    def info(self, msg):
        pass

    def warning(self, msg):
        pass

    def debug(self, msg):
        pass

logger = NoOpLogger()
_set_logger = False

def set_logger(new_logger):
    global _set_logger
    if _set_logger:
        raise Exception("Logger must be set at most once")
    global logger
    logger = new_logger
    _set_logger = True

def get_logger():
    return logger
@ -0,0 +1,559 @@
from collections import deque
import base64
import gzip as gzip_module
import hashlib
import os
import re
import time
import uuid
from six.moves import StringIO
from six import text_type, binary_type
try:
    from html import escape
except ImportError:
    from cgi import escape
def resolve_content(response):
return b"".join(item for item in response.iter_content(read_file=True))
class Pipeline(object):
pipes = {}
def __init__(self, pipe_string):
self.pipe_functions = self.parse(pipe_string)
def parse(self, pipe_string):
functions = []
for item in PipeTokenizer().tokenize(pipe_string):
if not item:
if item[0] == "function":
functions.append((self.pipes[item[1]], []))
elif item[0] == "argument":
return functions
def __call__(self, request, response):
for func, args in self.pipe_functions:
response = func(request, response, *args)
return response
class PipeTokenizer(object):
def __init__(self):
#This whole class can likely be replaced by some regexps
self.state = None
def tokenize(self, string):
self.string = string
self.state = self.func_name_state
self._index = 0
while self.state:
yield self.state()
yield None
def get_char(self):
if self._index >= len(self.string):
return None
rv = self.string[self._index]
self._index += 1
return rv
def func_name_state(self):
rv = ""
while True:
char = self.get_char()
if char is None:
self.state = None
if rv:
return ("function", rv)
return None
elif char == "(":
self.state = self.argument_state
return ("function", rv)
elif char == "|":
if rv:
return ("function", rv)
rv += char
def argument_state(self):
rv = ""
while True:
char = self.get_char()
if char is None:
self.state = None
return ("argument", rv)
elif char == "\\":
rv += self.get_escape()
if rv is None:
#This should perhaps be an error instead
return ("argument", rv)
elif char == ",":
return ("argument", rv)
elif char == ")":
self.state = self.func_name_state
return ("argument", rv)
rv += char
def get_escape(self):
char = self.get_char()
escapes = {"n": "\n",
"r": "\r",
"t": "\t"}
return escapes.get(char, char)
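The comment in `PipeTokenizer` suggests the class could be replaced by regular expressions. A minimal sketch of that idea follows, producing the same `("function", ...)` and `("argument", ...)` token shapes. It is an assumption-laden simplification: `tokenize_pipe` is a hypothetical name, and backslash escapes inside arguments are not handled.

```python
import re

# A function name, optionally followed by a parenthesised argument list.
_PIPE_RE = re.compile(r"([^(|]+)(?:\(([^)]*)\))?")


def tokenize_pipe(pipe_string):
    tokens = []
    for match in _PIPE_RE.finditer(pipe_string):
        name, args = match.groups()
        tokens.append(("function", name))
        if args is not None:
            tokens.extend(("argument", arg) for arg in args.split(","))
    return tokens


tokens = tokenize_pipe("sub(html)|header(X-Test,1)|gzip")
```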
class pipe(object):
def __init__(self, *arg_converters):
self.arg_converters = arg_converters
self.max_args = len(self.arg_converters)
self.min_args = 0
opt_seen = False
for item in self.arg_converters:
if not opt_seen:
if isinstance(item, opt):
opt_seen = True
self.min_args += 1
if not isinstance(item, opt):
raise ValueError("Non-optional argument cannot follow optional argument")
def __call__(self, f):
def inner(request, response, *args):
if not (self.min_args <= len(args) <= self.max_args):
raise ValueError("Expected between %d and %d args, got %d" %
(self.min_args, self.max_args, len(args)))
arg_values = tuple(f(x) for f, x in zip(self.arg_converters, args))
return f(request, response, *arg_values)
Pipeline.pipes[f.__name__] = inner
#We actually want the undecorated function in the main namespace
return f
class opt(object):
def __init__(self, f):
self.f = f
def __call__(self, arg):
return self.f(arg)
def nullable(func):
def inner(arg):
if arg.lower() == "null":
return None
return func(arg)
return inner
def boolean(arg):
if arg.lower() in ("true", "1"):
return True
elif arg.lower() in ("false", "0"):
return False
raise ValueError
def status(request, response, code):
"""Alter the status code.
:param code: Status code to use for the response."""
response.status = code
return response
@pipe(str, str, opt(boolean))
def header(request, response, name, value, append=False):
"""Set a HTTP header.
Replaces any existing HTTP header of the same name unless
append is set, in which case the header is appended without
replacing the existing ones.
:param name: Name of the header to set.
:param value: Value to use for the header.
:param append: True if existing headers should not be replaced
if not append:
response.headers.set(name, value)
response.headers.append(name, value)
return response
def trickle(request, response, delays):
"""Send the response in parts, with time delays.
:param delays: A string of delays and amounts, in bytes, of the
response to send. Each component is separated by
a colon. Amounts in bytes are plain integers, whilst
delays are floats prefixed with a single d e.g.
d1:100:d2
Would cause a 1 second delay, would then send 100 bytes
of the file, and then cause a 2 second delay, before sending
the remainder of the file.
If the last token is of the form rN, instead of sending the
remainder of the file, the previous N instructions will be
repeated until the whole file has been sent e.g.
d1:100:d2:r2
Causes a delay of 1s, then 100 bytes to be sent, then a 2s delay
and then a further 100 bytes followed by a two second delay
until the response has been fully sent.
def parse_delays():
parts = delays.split(":")
rv = []
for item in parts:
if item.startswith("d"):
item_type = "delay"
item = item[1:]
value = float(item)
elif item.startswith("r"):
item_type = "repeat"
value = int(item[1:])
if not value % 2 == 0:
raise ValueError
item_type = "bytes"
value = int(item)
if len(rv) and rv[-1][0] == item_type:
rv[-1][1] += value
rv.append((item_type, value))
return rv
delays = parse_delays()
if not delays:
return response
content = resolve_content(response)
offset = [0]
if not ("Cache-Control" in response.headers or
"Pragma" in response.headers or
"Expires" in response.headers):
response.headers.set("Cache-Control", "no-cache, no-store, must-revalidate")
response.headers.set("Pragma", "no-cache")
response.headers.set("Expires", "0")
def add_content(delays, repeat=False):
for i, (item_type, value) in enumerate(delays):
if item_type == "bytes":
yield content[offset[0]:offset[0] + value]
offset[0] += value
elif item_type == "delay":
elif item_type == "repeat":
if i != len(delays) - 1:
while offset[0] < len(content):
for item in add_content(delays[-(value + 1):-1], True):
yield item
if not repeat and offset[0] < len(content):
yield content[offset[0]:]
response.content = add_content(delays)
return response
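The delay string format described in the `trickle` docstring can be explored with a standalone parser. This is only a sketch: it takes the string as a parameter rather than closing over it, and keeps entries as lists while merging so that adjacent entries of the same type can be summed in place.

```python
def parse_delays(delays):
    """Parse e.g. 'd1:100:d2:r2' into (type, value) tuples."""
    rv = []
    for item in delays.split(":"):
        if item.startswith("d"):
            entry = ["delay", float(item[1:])]
        elif item.startswith("r"):
            value = int(item[1:])
            # A repeat covers (delay, bytes) pairs, so it must be even.
            if value % 2 != 0:
                raise ValueError("Repeat count must be even")
            entry = ["repeat", value]
        else:
            entry = ["bytes", int(item)]
        # Adjacent entries of the same type are merged by summing.
        if rv and rv[-1][0] == entry[0]:
            rv[-1][1] += entry[1]
        else:
            rv.append(entry)
    return [tuple(entry) for entry in rv]


repeating = parse_delays("d1:100:d2:r2")
merged = parse_delays("100:50:d0.5")
```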
@pipe(nullable(int), opt(nullable(int)))
def slice(request, response, start, end=None):
"""Send a byte range of the response body
:param start: The starting offset. Follows python semantics including
negative numbers.
:param end: The ending offset, again with python semantics and None
(spelled "null" in a query string) to indicate the end of
the file.
content = resolve_content(response)[start:end]
response.content = content
response.headers.set("Content-Length", len(content))
return response
class ReplacementTokenizer(object):
def arguments(self, token):
unwrapped = token[1:-1].decode('utf8')
return ("arguments", re.split(r",\s*", unwrapped) if unwrapped else [])
def ident(self, token):
return ("ident", token.decode('utf8'))
def index(self, token):
token = token[1:-1].decode('utf8')
index = int(token)
except ValueError:
index = token
return ("index", index)
def var(self, token):
token = token[:-1].decode('utf8')
return ("var", token)
def tokenize(self, string):
assert isinstance(string, binary_type)
return self.scanner.scan(string)[0]
scanner = re.Scanner([(br"\$\w+:", var),
(br"\$?\w+", ident),
(br"\[[^\]]*\]", index),
(br"\([^)]*\)", arguments)])
class FirstWrapper(object):
def __init__(self, params):
self.params = params
def __getitem__(self, key):
if isinstance(key, text_type):
key = key.encode('iso-8859-1')
return self.params.first(key)
except KeyError:
return ""
def sub(request, response, escape_type="html"):
"""Substitute environment information about the server and request into the script.
:param escape_type: String detailing the type of escaping to use. Known values are
"html" and "none", with "html" the default for historic reasons.
The format is a very limited template language. Substitutions are
enclosed by {{ and }}. There are several available substitutions:
host
  A simple string value representing the primary host from which the
  tests are being run.
domains
  A dictionary of available domains indexed by subdomain name.
ports
  A dictionary of lists of ports indexed by protocol.
location
  A dictionary of parts of the request URL. Valid keys are
  'server', 'scheme', 'host', 'hostname', 'port', 'path' and 'query'.
  'server' is scheme://host:port, 'host' is hostname:port, and query
  includes the leading '?', but other delimiters are omitted.
headers
  A dictionary of HTTP headers in the request.
header_or_default(header, default)
  The value of an HTTP header, or a default value if it is absent.
  For example::
  {{header_or_default(X-Test, test-header-absent)}}
GET
  A dictionary of query parameters supplied with the request.
uuid()
  A pseudo-random UUID suitable for usage with stash.
file_hash(algorithm, filepath)
  The cryptographic hash of a file. Supported algorithms: md5, sha1,
  sha224, sha256, sha384, and sha512. For example::
  {{file_hash(md5, dom/interfaces.html)}}
fs_path(filepath)
  The absolute path to a file inside the wpt document root.
So for example in a setup running on localhost with a www
subdomain and a http server on ports 80 and 81::
{{host}} => localhost
{{domains[www]}} => www.localhost
{{ports[http][1]}} => 81
It is also possible to assign a value to a variable name, which must start
with the $ character, using the ":" syntax e.g.::
Later substitutions in the same file may then refer to the variable
by name e.g.::
content = resolve_content(response)
new_content = template(request, content, escape_type=escape_type)
response.content = new_content
return response
class SubFunctions(object):
def uuid(request):
return str(uuid.uuid4())
# Maintain a list of supported algorithms, restricted to those that are
# available on all platforms [1]. This ensures that test authors do not
# unknowingly introduce platform-specific tests.
# [1] https://docs.python.org/2/library/hashlib.html
supported_algorithms = ("md5", "sha1", "sha224", "sha256", "sha384", "sha512")
def file_hash(request, algorithm, path):
assert isinstance(algorithm, text_type)
if algorithm not in SubFunctions.supported_algorithms:
raise ValueError("Unsupported encryption algorithm: '%s'" % algorithm)
hash_obj = getattr(hashlib, algorithm)()
absolute_path = os.path.join(request.doc_root, path)
with open(absolute_path, "rb") as f:
except IOError:
# In this context, an unhandled IOError will be interpreted by the
# server as an indication that the template file is non-existent.
# Although the generic "Exception" is less precise, it avoids
# triggering a potentially-confusing HTTP 404 error in cases where
# the path to the file to be hashed is invalid.
raise Exception('Cannot open file for hash computation: "%s"' % absolute_path)
return base64.b64encode(hash_obj.digest()).strip()
def fs_path(request, path):
if not path.startswith("/"):
subdir = request.request_path[len(request.url_base):]
if "/" in subdir:
subdir = subdir.rsplit("/", 1)[0]
root_rel_path = subdir + "/" + path
root_rel_path = path[1:]
root_rel_path = root_rel_path.replace("/", os.path.sep)
absolute_path = os.path.abspath(os.path.join(request.doc_root, root_rel_path))
if ".." in os.path.relpath(absolute_path, request.doc_root):
raise ValueError("Path outside wpt root")
return absolute_path
def header_or_default(request, name, default):
return request.headers.get(name, default)
def template(request, content, escape_type="html"):
#TODO: There basically isn't any error handling here
tokenizer = ReplacementTokenizer()
variables = {}
def config_replacement(match):
content, = match.groups()
tokens = tokenizer.tokenize(content)
tokens = deque(tokens)
token_type, field = tokens.popleft()
assert isinstance(field, text_type)
if token_type == "var":
variable = field
token_type, field = tokens.popleft()
assert isinstance(field, text_type)
variable = None
if token_type != "ident":
raise Exception("unexpected token type %s (token '%r'), expected ident" % (token_type, field))
if field in variables:
value = variables[field]
elif hasattr(SubFunctions, field):
value = getattr(SubFunctions, field)
elif field == "headers":
value = request.headers
elif field == "GET":
value = FirstWrapper(request.GET)
elif field == "hosts":
value = request.server.config.all_domains
elif field == "domains":
value = request.server.config.all_domains[""]
elif field == "host":
value = request.server.config["browser_host"]
elif field in request.server.config:
value = request.server.config[field]
elif field == "location":
value = {"server": "%s://%s:%s" % (request.url_parts.scheme,
"scheme": request.url_parts.scheme,
"host": "%s:%s" % (request.url_parts.hostname,
"hostname": request.url_parts.hostname,
"port": request.url_parts.port,
"path": request.url_parts.path,
"pathname": request.url_parts.path,
"query": "?%s" % request.url_parts.query}
elif field == "url_base":
value = request.url_base
raise Exception("Undefined template variable %s" % field)
while tokens:
ttype, field = tokens.popleft()
if ttype == "index":
value = value[field]
elif ttype == "arguments":
value = value(request, *field)
raise Exception(
"unexpected token type %s (token '%r'), expected ident or arguments" % (ttype, field)
assert isinstance(value, (int, (binary_type, text_type))), tokens
if variable is not None:
variables[variable] = value
escape_func = {"html": lambda x:escape(x, quote=True),
"none": lambda x:x}[escape_type]
# Should possibly support escaping for other contexts e.g. script
# TODO: read the encoding of the response
# cgi.escape() only takes text strings in Python 3.
if isinstance(value, binary_type):
value = value.decode("utf-8")
elif isinstance(value, int):
value = text_type(value)
return escape_func(value).encode("utf-8")
template_regexp = re.compile(br"{{([^}]*)}}")
new_content = template_regexp.sub(config_replacement, content)
return new_content
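The core of the `{{...}}` substitution is a regular-expression replace over bytes with a callback. The sketch below shows only that mechanism with a flat dictionary of values; `render` is a hypothetical name, and indexing, function calls and `$variable:` assignment from the real template language are deliberately left out. It is written for Python 3.

```python
import re
from html import escape

TEMPLATE_RE = re.compile(br"{{([^}]*)}}")


def render(content, values, escape_type="html"):
    escape_func = {"html": lambda x: escape(x, quote=True),
                   "none": lambda x: x}[escape_type]

    def replacement(match):
        field = match.group(1).decode("utf-8").strip()
        if field not in values:
            raise KeyError("Undefined template variable %s" % field)
        # Values are converted to text, escaped, then re-encoded.
        return escape_func(str(values[field])).encode("utf-8")

    return TEMPLATE_RE.sub(replacement, content)


page = render(b"<p>{{host}}:{{port}} {{note}}</p>",
              {"host": "localhost", "port": 8000, "note": "a<b"})
raw = render(b"{{note}}", {"note": "a<b"}, escape_type="none")
```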
def gzip(request, response):
"""This pipe gzip-encodes response data.
It sets (or overwrites) these HTTP headers:
Content-Encoding is set to gzip
Content-Length is set to the length of the compressed content
content = resolve_content(response)
response.headers.set("Content-Encoding", "gzip")
out = StringIO()
with gzip_module.GzipFile(fileobj=out, mode="w") as f:
    f.write(content)
response.content = out.getvalue()
response.headers.set("Content-Length", len(response.content))
return response
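For reference, the compression step of the `gzip` pipe can be reproduced outside the server. This sketch is written for Python 3, where the in-memory buffer has to be `io.BytesIO`; the helper name `gzip_body` and the plain `headers` dict are illustrative assumptions.

```python
import gzip
from io import BytesIO


def gzip_body(content):
    """Return the gzip-compressed form of a bytes body."""
    out = BytesIO()
    with gzip.GzipFile(fileobj=out, mode="w") as f:
        f.write(content)
    return out.getvalue()


body = b"hello world" * 100
compressed = gzip_body(body)
# The pipe sets both headers from the compressed payload.
headers = {"Content-Encoding": "gzip",
           "Content-Length": str(len(compressed))}
```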
@ -0,0 +1,94 @@
from .utils import HTTPException
class RangeParser(object):
def __call__(self, header, file_size):
header = header.decode("ascii")
except UnicodeDecodeError:
raise HTTPException(400, "Non-ASCII range header value")
prefix = "bytes="
if not header.startswith(prefix):
raise HTTPException(416, message="Unrecognised range type %s" % (header,))
parts = header[len(prefix):].split(",")
ranges = []
for item in parts:
components = item.split("-")
if len(components) != 2:
raise HTTPException(416, "Bad range specifier %s" % (item))
data = []
for component in components:
if component == "":
except ValueError:
raise HTTPException(416, "Bad range specifier %s" % (item))
ranges.append(Range(data[0], data[1], file_size))
except ValueError:
raise HTTPException(416, "Bad range specifier %s" % (item))
return self.coalesce_ranges(ranges, file_size)
def coalesce_ranges(self, ranges, file_size):
rv = []
target = None
for current in reversed(sorted(ranges)):
if target is None:
target = current
new = target.coalesce(current)
target = new[0]
if len(new) > 1:
return rv[::-1]
class Range(object):
def __init__(self, lower, upper, file_size):
self.file_size = file_size
self.lower, self.upper = self._abs(lower, upper)
if self.lower >= self.upper or self.lower >= self.file_size:
raise ValueError
def __repr__(self):
return "<Range %s-%s>" % (self.lower, self.upper)
def __lt__(self, other):
return self.lower < other.lower
def __gt__(self, other):
return self.lower > other.lower
def __eq__(self, other):
return self.lower == other.lower and self.upper == other.upper
def _abs(self, lower, upper):
if lower is None and upper is None:
lower, upper = 0, self.file_size
elif lower is None:
lower, upper = max(0, self.file_size - upper), self.file_size
elif upper is None:
lower, upper = lower, self.file_size
lower, upper = lower, min(self.file_size, upper + 1)
return lower, upper
def coalesce(self, other):
assert self.file_size == other.file_size
if (self.upper < other.lower or self.lower > other.upper):
return sorted([self, other])
return [Range(min(self.lower, other.lower),
max(self.upper, other.upper) - 1,
def header_value(self):
return "bytes %i-%i/%i" % (self.lower, self.upper - 1, self.file_size)
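The `Range` class stores each range as a half-open interval `[lower, upper)`, while HTTP uses inclusive byte positions. The sketch below restates that conversion as free functions so it can be checked by hand; the names `absolute_range` and `content_range` are hypothetical.

```python
def absolute_range(lower, upper, file_size):
    """Map an HTTP byte-range spec to a half-open [lower, upper) pair."""
    if lower is None and upper is None:
        return 0, file_size
    if lower is None:
        # Suffix range "-N": the last N bytes of the file.
        return max(0, file_size - upper), file_size
    if upper is None:
        # Open-ended range "N-": from N to the end of the file.
        return lower, file_size
    # "A-B" is inclusive in HTTP, so the exclusive upper bound is B + 1.
    return lower, min(file_size, upper + 1)


def content_range(lower, upper, file_size):
    """Format a half-open range as a Content-Range header value."""
    return "bytes %i-%i/%i" % (lower, upper - 1, file_size)


first_hundred = absolute_range(0, 99, 1000)
last_hundred = absolute_range(None, 100, 1000)
header = content_range(*last_hundred, 1000)
```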
@ -0,0 +1,688 @@
import base64
import cgi
import tempfile
from six import BytesIO, binary_type, iteritems, PY3
from six.moves.http_cookies import BaseCookie
from six.moves.urllib.parse import parse_qsl, urlsplit
from . import stash
from .utils import HTTPException, isomorphic_encode, isomorphic_decode
missing = object()
class Server(object):
"""Data about the server environment
.. attribute:: config
Environment configuration information with information about the
various servers running, their hostnames and ports.
.. attribute:: stash
Stash object holding state stored on the server between requests.
config = None
def __init__(self, request):
self._stash = None
self._request = request
def stash(self):
if self._stash is None:
address, authkey = stash.load_env_config()
self._stash = stash.Stash(self._request.url_parts.path, address, authkey)
return self._stash
class InputFile(object):
max_buffer_size = 1024*1024
def __init__(self, rfile, length):
"""File-like object used to provide a seekable view of request body data"""
self._file = rfile
self.length = length
self._file_position = 0
if length > self.max_buffer_size:
self._buf = tempfile.TemporaryFile()
self._buf = BytesIO()
def _buf_position(self):
rv = self._buf.tell()
assert rv <= self._file_position
return rv
def read(self, bytes=-1):
assert self._buf_position <= self._file_position
if bytes < 0:
bytes = self.length - self._buf_position
bytes_remaining = min(bytes, self.length - self._buf_position)
if bytes_remaining == 0:
return b""
if self._buf_position != self._file_position:
buf_bytes = min(bytes_remaining, self._file_position - self._buf_position)
old_data = self._buf.read(buf_bytes)
bytes_remaining -= buf_bytes
old_data = b""
assert bytes_remaining == 0 or self._buf_position == self._file_position, (
"Before reading buffer position (%i) didn't match file position (%i)" %
(self._buf_position, self._file_position))
new_data = self._file.read(bytes_remaining)
self._file_position += bytes_remaining
assert bytes_remaining == 0 or self._buf_position == self._file_position, (
"After reading buffer position (%i) didn't match file position (%i)" %
(self._buf_position, self._file_position))
return old_data + new_data
def tell(self):
return self._buf_position
def seek(self, offset):
if offset > self.length or offset < 0:
raise ValueError
if offset <= self._file_position:
self.read(offset - self._file_position)
def readline(self, max_bytes=None):
if max_bytes is None:
max_bytes = self.length - self._buf_position
if self._buf_position < self._file_position:
data = self._buf.readline(max_bytes)
if data.endswith(b"\n") or len(data) == max_bytes:
return data
data = b""
assert self._buf_position == self._file_position
initial_position = self._file_position
found = False
buf = []
max_bytes -= len(data)
while not found:
readahead = self.read(min(2, max_bytes))
max_bytes -= len(readahead)
for i, c in enumerate(readahead):
if c == b"\n"[0]:
found = True
if not found:
if not readahead or not max_bytes:
new_data = b"".join(buf)
data += new_data
self.seek(initial_position + len(new_data))
return data
def readlines(self):
rv = []
while True:
data = self.readline()
if data:
return rv
def __next__(self):
data = self.readline()
if data:
return data
raise StopIteration
next = __next__
def __iter__(self):
return self
class Request(object):
"""Object representing a HTTP request.
.. attribute:: doc_root
The local directory to use as a base when resolving paths
.. attribute:: route_match
Regexp match object from matching the request path to the route
selected for the request.
.. attribute:: client_address
Contains a tuple of the form (host, port) representing the client's address.
.. attribute:: protocol_version
HTTP version specified in the request.
.. attribute:: method
HTTP method in the request.
.. attribute:: request_path
Request path as it appears in the HTTP request.
.. attribute:: url_base
The prefix part of the path; typically / unless the handler has a url_base set
.. attribute:: url
Absolute URL for the request.
.. attribute:: url_parts
Parts of the requested URL as obtained by urlparse.urlsplit(path)
.. attribute:: request_line
Raw request line
.. attribute:: headers
RequestHeaders object providing a dictionary-like representation of
the request headers.
.. attribute:: raw_headers
Dictionary of non-normalized request headers.
.. attribute:: body
Request body as a string
.. attribute:: raw_input
File-like object representing the body of the request.
.. attribute:: GET
MultiDict representing the parameters supplied with the request.
Note that these may be present on non-GET requests; the name is
chosen to be familiar to users of other systems such as PHP.
Both keys and values are binary strings.
.. attribute:: POST
MultiDict representing the request body parameters. Most parameters
are present as string values, but file uploads have file-like
values. All string values (including keys) have binary type.
.. attribute:: cookies
A Cookies object representing cookies sent with the request with a
dictionary-like interface.
.. attribute:: auth
An instance of Authentication with username and password properties
representing any credentials supplied using HTTP authentication.
.. attribute:: server
Server object containing information about the server environment.
def __init__(self, request_handler):
self.doc_root = request_handler.server.router.doc_root
self.route_match = None # Set by the router
self.client_address = request_handler.client_address
self.protocol_version = request_handler.protocol_version
self.method = request_handler.command
# Keys and values in raw headers are native strings.
self._headers = None
self.raw_headers = request_handler.headers
scheme = request_handler.server.scheme
host = self.raw_headers.get("Host")
port = request_handler.server.server_address[1]
if host is None:
host = request_handler.server.server_address[0]
if ":" in host:
host, port = host.split(":", 1)
self.request_path = request_handler.path
self.url_base = "/"
if self.request_path.startswith(scheme + "://"):
self.url = self.request_path
# TODO(#23362): Stop using native strings for URLs.
self.url = "%s://%s:%s%s" % (
scheme, host, port, self.request_path)
self.url_parts = urlsplit(self.url)
self.request_line = request_handler.raw_requestline
self.raw_input = InputFile(request_handler.rfile,
int(self.raw_headers.get("Content-Length", 0)))
self._body = None
self._GET = None
self._POST = None
self._cookies = None
self._auth = None
self.server = Server(self)
def __repr__(self):
return "<Request %s %s>" % (self.method, self.url)
def GET(self):
if self._GET is None:
kwargs = {
"keep_blank_values": True,
if PY3:
kwargs["encoding"] = "iso-8859-1"
params = parse_qsl(self.url_parts.query, **kwargs)
self._GET = MultiDict()
for key, value in params:
self._GET.add(isomorphic_encode(key), isomorphic_encode(value))
return self._GET
def POST(self):
if self._POST is None:
# Work out the post parameters
pos = self.raw_input.tell()
kwargs = {
"fp": self.raw_input,
"environ": {"REQUEST_METHOD": self.method},
"headers": self.raw_headers,
"keep_blank_values": True,
if PY3:
kwargs["encoding"] = "iso-8859-1"
fs = cgi.FieldStorage(**kwargs)
self._POST = MultiDict.from_field_storage(fs)
return self._POST
def cookies(self):
if self._cookies is None:
parser = BinaryCookieParser()
cookie_headers = self.headers.get("cookie", b"")
cookies = Cookies()
for key, value in iteritems(parser):
cookies[isomorphic_encode(key)] = CookieValue(value)
self._cookies = cookies
return self._cookies
def headers(self):
if self._headers is None:
self._headers = RequestHeaders(self.raw_headers)
return self._headers
def body(self):
if self._body is None:
pos = self.raw_input.tell()
self._body = self.raw_input.read()
return self._body
def auth(self):
if self._auth is None:
self._auth = Authentication(self.headers)
return self._auth
class H2Request(Request):
def __init__(self, request_handler):
self.h2_stream_id = request_handler.h2_stream_id
self.frames = []
super(H2Request, self).__init__(request_handler)
class RequestHeaders(dict):
"""Read-only dictionary-like API for accessing request headers.
Unlike BaseHTTPRequestHandler.headers, this class always returns all
headers with the same name (separated by commas). And it ensures all keys
(i.e. names of headers) and values have binary type.
def __init__(self, items):
for header in items.keys():
key = isomorphic_encode(header).lower()
# get all headers with the same name
values = items.getallmatchingheaders(header)
if len(values) > 1:
# collect the multiple variations of the current header
multiples = []
# loop through the values from getallmatchingheaders
for value in values:
# getallmatchingheaders returns raw header lines, so
# split to get name, value
multiples.append(isomorphic_encode(value).split(b':', 1)[1].strip())
headers = multiples
headers = [isomorphic_encode(items[header])]
dict.__setitem__(self, key, headers)
def __getitem__(self, key):
"""Get all headers of a certain (case-insensitive) name. If there is
more than one, the values are returned comma separated"""
key = isomorphic_encode(key)
values = dict.__getitem__(self, key.lower())
if len(values) == 1:
return values[0]
return b", ".join(values)
def __setitem__(self, name, value):
raise Exception
def get(self, key, default=None):
"""Get a string representing all headers with a particular name,
with multiple headers separated by a comma. If no header is found
return a default value
:param key: The header name to look up (case-insensitive)
:param default: The value to return in the case of no match
return self[key]
except KeyError:
return default
def get_list(self, key, default=missing):
"""Get all the header values for a particular field name as
a list"""
key = isomorphic_encode(key)
return dict.__getitem__(self, key.lower())
except KeyError:
if default is not missing:
return default
def __contains__(self, key):
key = isomorphic_encode(key)
return dict.__contains__(self, key.lower())
def iteritems(self):
for item in self:
yield item, self[item]
def itervalues(self):
for item in self:
yield self[item]
class CookieValue(object):
"""Representation of cookies.
Note that cookies are considered read-only and the string value
of the cookie will not change if you update the field values.
However this is not enforced.
.. attribute:: key
The name of the cookie.
.. attribute:: value
The value of the cookie
.. attribute:: expires
The expiry date of the cookie
.. attribute:: path
The path of the cookie
.. attribute:: comment
The comment of the cookie.
.. attribute:: domain
The domain with which the cookie is associated
.. attribute:: max_age
The max-age value of the cookie.
.. attribute:: secure
Whether the cookie is marked as secure
.. attribute:: httponly
Whether the cookie is marked as httponly
def __init__(self, morsel):
self.key = morsel.key
self.value = morsel.value
for attr in ["expires", "path",
"comment", "domain", "max-age",
"secure", "version", "httponly"]:
setattr(self, attr.replace("-", "_"), morsel[attr])
self._str = morsel.OutputString()
def __str__(self):
return self._str
def __repr__(self):
return self._str
def __eq__(self, other):
"""Equality comparison for cookies. Compares to other cookies
based on value alone and on non-cookies based on the equality
of self.value with the other object so that a cookie with value
"ham" compares equal to the string "ham"
if hasattr(other, "value"):
return self.value == other.value
return self.value == other
class MultiDict(dict):
"""Dictionary type that holds multiple values for each key"""
# TODO: this should perhaps also order the keys
def __init__(self):
def __setitem__(self, name, value):
dict.__setitem__(self, name, [value])
def add(self, name, value):
if name in self:
dict.__getitem__(self, name).append(value)
dict.__setitem__(self, name, [value])
def __getitem__(self, key):
"""Get the first value with a given key"""
return self.first(key)
def first(self, key, default=missing):
"""Get the first value with a given key
:param key: The key to look up
:param default: The default to return if key is
not found (throws if nothing is
if key in self and dict.__getitem__(self, key):
return dict.__getitem__(self, key)[0]
elif default is not missing:
return default
raise KeyError(key)
def last(self, key, default=missing):
"""Get the last value with a given key
:param key: The key to look up
:param default: The default to return if key is
not found (throws if nothing is
if key in self and dict.__getitem__(self, key):
return dict.__getitem__(self, key)[-1]
elif default is not missing:
return default
raise KeyError(key)
# We need to explicitly override dict.get; otherwise, it won't call
# __getitem__ and would return a list instead.
def get(self, key, default=None):
"""Get the first value with a given key
:param key: The key to look up
:param default: The default to return if key is
not found (None by default)
return self.first(key, default)
def get_list(self, key):
"""Get all values with a given key as a list
:param key: The key to look up
if key in self:
return dict.__getitem__(self, key)
return []
def from_field_storage(cls, fs):
"""Construct a MultiDict from a cgi.FieldStorage
Note that all keys and values are binary strings.
self = cls()
if fs.list is None:
return self
for key in fs:
values = fs[key]
if not isinstance(values, list):
values = [values]
for value in values:
if not value.filename:
value = isomorphic_encode(value.value)
assert isinstance(value, cgi.FieldStorage)
self.add(isomorphic_encode(key), value)
return self
class BinaryCookieParser(BaseCookie):
"""A subclass of BaseCookie that returns values in binary strings
This is not intended to store the cookies; use Cookies instead.
def value_decode(self, val):
"""Decode value from network to (real_value, coded_value).
Override BaseCookie.value_decode.
return isomorphic_encode(val), val
def value_encode(self, val):
raise NotImplementedError('BinaryCookieParser is not for setting cookies')
def load(self, rawdata):
"""Load cookies from a binary string.
This overrides and calls BaseCookie.load. Unlike BaseCookie.load, it
does not accept dictionaries.
assert isinstance(rawdata, binary_type)
if PY3:
# BaseCookie.load expects a native string, which in Python 3 is text.
rawdata = isomorphic_decode(rawdata)
super(BinaryCookieParser, self).load(rawdata)
class Cookies(MultiDict):
"""MultiDict specialised for Cookie values
Keys and values are binary strings.
def __init__(self):
def __getitem__(self, key):
return self.last(key)
class Authentication(object):
"""Object for dealing with HTTP Authentication
.. attribute:: username
The username supplied in the HTTP Authorization
header, or None
.. attribute:: password
The password supplied in the HTTP Authorization
header, or None
Both attributes are binary strings (`str` in Py2, `bytes` in Py3), since
RFC7617 Section 2.1 does not specify the encoding for username & password
(as long as it's compatible with ASCII). UTF-8 should be a relatively safe
choice if callers need to decode them as most browsers use it.
def __init__(self, headers):
self.username = None
self.password = None
auth_schemes = {b"Basic": self.decode_basic}
if "authorization" in headers:
header = headers.get("authorization")
assert isinstance(header, binary_type)
auth_type, data = header.split(b" ", 1)
if auth_type in auth_schemes:
self.username, self.password = auth_schemes[auth_type](data)
raise HTTPException(400, "Unsupported authentication scheme %s" % auth_type)
def decode_basic(self, data):
assert isinstance(data, binary_type)
decoded_data = base64.b64decode(data)
return decoded_data.split(b":", 1)
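As a rough illustration of the Basic scheme handling in `Authentication` above, the following standalone sketch splits a binary `Authorization` header value into a username and a password. `parse_basic_authorization` is a hypothetical helper written for this example, not a wptserve function, and it raises `ValueError` where the code above raises `HTTPException`.

```python
import base64


def parse_basic_authorization(header):
    """Split a binary Authorization header value into (username, password).

    Hypothetical helper for illustration only; not part of wptserve.
    """
    scheme, _, data = header.partition(b" ")
    if scheme != b"Basic":
        raise ValueError("Unsupported authentication scheme %r" % (scheme,))
    decoded = base64.b64decode(data)
    # Split on the first colon only: the password may itself contain colons.
    username, _, password = decoded.partition(b":")
    return username, password


# Both values stay binary; callers decode them (for example as UTF-8) if needed.
credentials = parse_basic_authorization(
    b"Basic " + base64.b64encode(b"user:pa:ss"))
```

Keeping both values as binary strings mirrors the note in the `Authentication` docstring that the encoding of the credentials is left to the caller.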
@ -0,0 +1,813 @@
from collections import OrderedDict
from datetime import datetime, timedelta
from io import BytesIO
import json
import socket
import uuid
from hpack.struct import HeaderTuple
from hyperframe.frame import HeadersFrame, DataFrame, ContinuationFrame
from six import binary_type, text_type, integer_types, itervalues, PY3
from six.moves.http_cookies import BaseCookie, Morsel
from .constants import response_codes, h2_headers
from .logger import get_logger
from .utils import isomorphic_decode, isomorphic_encode
missing = object()
class Response(object):
"""Object representing the response to an HTTP request
:param handler: RequestHandler being used for this response
:param request: Request that this is the response for
.. attribute:: request
Request associated with this Response.
.. attribute:: encoding
The encoding to use when converting unicode to strings for output.
.. attribute:: add_required_headers
Boolean indicating whether mandatory headers should be added to the response.
.. attribute:: send_body_for_head_request
Boolean, default False, indicating whether the body content should be
sent when the request method is HEAD.
.. attribute:: writer
The ResponseWriter for this response
.. attribute:: status
Status tuple (code, message). Can be set to an integer in which case the
message part is filled in automatically, or a tuple (code, message) in
which case code is an int and message is a text or binary string.
.. attribute:: headers
List of HTTP headers to send with the response. Each item in the list is a
tuple of (name, value).
.. attribute:: content
The body of the response. This can either be a string or an iterable of response
parts. If it is an iterable, any item may be a string or a function of zero
parameters which, when called, returns a string."""
def __init__(self, handler, request, response_writer_cls=None):
self.request = request
self.encoding = "utf8"
self.add_required_headers = True
self.send_body_for_head_request = False
self.close_connection = False
self.logger = get_logger()
self.writer = response_writer_cls(handler, self) if response_writer_cls else ResponseWriter(handler, self)
self._status = (200, None)
self.headers = ResponseHeaders()
self.content = []
def status(self):
return self._status
def status(self, value):
if hasattr(value, "__len__"):
if len(value) != 2:
raise ValueError
code = int(value[0])
message = value[1]
# Only call str() if message is not a string type, so that we
# don't get `str(b"foo") == "b'foo'"` in Python 3.
if not isinstance(message, (binary_type, text_type)):
message = str(message)
self._status = (code, message)
self._status = (int(value), None)
def set_cookie(self, name, value, path="/", domain=None, max_age=None,
expires=None, secure=False, httponly=False, comment=None):
"""Set a cookie to be sent with a Set-Cookie header in the response
:param name: name of the cookie (a binary string)
:param value: value of the cookie (a binary string, or None)
:param max_age: datetime.timedelta or int representing the time (in seconds)
until the cookie expires
:param path: String path to which the cookie applies
:param domain: String domain to which the cookie applies
:param secure: Boolean indicating whether the cookie is marked as secure
:param httponly: Boolean indicating whether the cookie is marked as httponly
:param comment: String comment
:param expires: datetime.datetime or datetime.timedelta indicating a
time or interval from now when the cookie expires
# TODO(Python 3): Convert other parameters (e.g. path) to bytes, too.
if value is None:
value = b''
max_age = 0
expires = timedelta(days=-1)
if PY3:
name = isomorphic_decode(name)
value = isomorphic_decode(value)
# Map month numbers (1-12) to the abbreviated month names used in the Expires value.
months = {i+1: month for i, month in enumerate(["jan", "feb", "mar",
"apr", "may", "jun",
"jul", "aug", "sep",
"oct", "nov", "dec"])}
if isinstance(expires, timedelta):
expires = datetime.utcnow() + expires
if expires is not None:
expires_str = expires.strftime("%d %%s %Y %H:%M:%S GMT")
expires_str = expires_str % months[expires.month]
expires = expires_str
if max_age is not None:
if hasattr(max_age, "total_seconds"):
max_age = int(max_age.total_seconds())
max_age = "%.0d" % max_age
m = Morsel()
def maybe_set(key, value):
if value is not None and value is not False:
m[key] = value
m.set(name, value, value)
maybe_set("path", path)
maybe_set("domain", domain)
maybe_set("comment", comment)
maybe_set("expires", expires)
maybe_set("max-age", max_age)
maybe_set("secure", secure)
maybe_set("httponly", httponly)
self.headers.append("Set-Cookie", m.OutputString())
def unset_cookie(self, name):
"""Remove a cookie from those that are being sent with the response"""
if PY3:
name = isomorphic_decode(name)
cookies = self.headers.get("Set-Cookie")
parser = BaseCookie()
for cookie in cookies:
if PY3:
# BaseCookie.load expects a text string.
cookie = isomorphic_decode(cookie)
if name in parser.keys():
del self.headers["Set-Cookie"]
for m in parser.values():
if m.key != name:
self.headers.append("Set-Cookie", m.OutputString())
def delete_cookie(self, name, path="/", domain=None):
"""Delete a cookie on the client by setting it to the empty string
and to expire in the past"""
self.set_cookie(name, None, path=path, domain=domain, max_age=0,
def iter_content(self, read_file=False):
"""Iterator returning chunks of response body content.
If any part of the content is a function, this will be called
and the resulting value (if any) returned.
:param read_file: boolean controlling the behaviour when content is a
file handle. When set to False the handle will be
returned directly allowing the file to be passed to
the output in small chunks. When set to True, the
entire content of the file will be returned as a
string facilitating non-streaming operations like
template substitution.
if isinstance(self.content, binary_type):
yield self.content
elif isinstance(self.content, text_type):
yield self.content.encode(self.encoding)
elif hasattr(self.content, "read"):
if read_file:
yield self.content.read()
yield self.content
for item in self.content:
if hasattr(item, "__call__"):
value = item()
value = item
if value:
yield value
def write_status_headers(self):
"""Write out the status line and headers for the response"""
for item in self.headers:
def write_content(self):
"""Write out the response content"""
if self.request.method != "HEAD" or self.send_body_for_head_request:
for item in self.iter_content():
def write(self):
"""Write the whole response"""
def set_error(self, code, message=u""):
"""Set the response status headers and return a JSON error object:
{"error": {"code": code, "message": message}}
code is an int (HTTP status code), and message is a text string.
err = {"code": code,
"message": message}
data = json.dumps({"error": err})
self.status = code
self.headers = [("Content-Type", "application/json"),
("Content-Length", len(data))]
self.content = data
if code == 500:
class MultipartContent(object):
def __init__(self, boundary=None, default_content_type=None):
self.items = []
if boundary is None:
boundary = text_type(uuid.uuid4())
self.boundary = boundary
self.default_content_type = default_content_type
def __call__(self):
boundary = b"--" + self.boundary.encode("ascii")
rv = [b"", boundary]
for item in self.items:
rv[-1] += b"--"
return b"\r\n".join(rv)
def append_part(self, data, content_type=None, headers=None):
if content_type is None:
content_type = self.default_content_type
self.items.append(MultipartPart(data, content_type, headers))
def __iter__(self):
# This is hackish; when writing the response we need an iterable
# or a string. For a multipart/byterange response we want an
# iterable that contains a single callable: the MultipartContent
# object itself.
yield self
class MultipartPart(object):
def __init__(self, data, content_type=None, headers=None):
assert isinstance(data, binary_type), data
self.headers = ResponseHeaders()
if content_type is not None:
self.headers.set("Content-Type", content_type)
if headers is not None:
for name, value in headers:
if name.lower() == b"content-type":
func = self.headers.set
func = self.headers.append
func(name, value)
self.data = data
def to_bytes(self):
rv = []
for key, value in self.headers:
assert isinstance(key, binary_type)
assert isinstance(value, binary_type)
rv.append(b"%s: %s" % (key, value))
return b"\r\n".join(rv)
def _maybe_encode(s):
"""Encode a string or an int into binary data using isomorphic_encode()."""
if isinstance(s, integer_types):
return b"%i" % (s,)
return isomorphic_encode(s)
class ResponseHeaders(object):
"""Dictionary-like object holding the headers for the response"""
def __init__(self):
self.data = OrderedDict()
def set(self, key, value):
"""Set a header to a specific value, overwriting any previous header
with the same name
:param key: Name of the header to set
:param value: Value to set the header to
key = _maybe_encode(key)
value = _maybe_encode(value)
self.data[key.lower()] = (key, [value])
def append(self, key, value):
"""Add a new header with a given name, not overwriting any existing
headers with the same name
:param key: Name of the header to add
:param value: Value to set for the header
key = _maybe_encode(key)
value = _maybe_encode(value)
if key.lower() in self.data:
self.set(key, value)
def get(self, key, default=missing):
"""Get the set values for a particular header."""
key = _maybe_encode(key)
return self[key]
except KeyError:
if default is missing:
return []
return default
def __getitem__(self, key):
"""Get a list of values for a particular header
key = _maybe_encode(key)
return self.data[key.lower()][1]
def __delitem__(self, key):
key = _maybe_encode(key)
del self.data[key.lower()]
def __contains__(self, key):
key = _maybe_encode(key)
return key.lower() in self.data
def __setitem__(self, key, value):
self.set(key, value)
def __iter__(self):
for key, values in itervalues(self.data):
for value in values:
yield key, value
def items(self):
return list(self)
def update(self, items_iter):
for name, value in items_iter:
self.append(name, value)
def __repr__(self):
return repr(self.data)
class H2Response(Response):
def __init__(self, handler, request):
super(H2Response, self).__init__(handler, request, response_writer_cls=H2ResponseWriter)
def write_status_headers(self):
self.writer.write_headers(self.headers, *self.status)
# Hacky way of detecting last item in generator
def write_content(self):
"""Write out the response content"""
if self.request.method != "HEAD" or self.send_body_for_head_request:
item = None
item_iter = self.iter_content()
item = next(item_iter)
while True:
check_last = next(item_iter)
self.writer.write_data(item, last=False)
item = check_last
except StopIteration:
if item:
self.writer.write_data(item, last=True)
class H2ResponseWriter(object):
def __init__(self, handler, response):
self.socket = handler.request
self.h2conn = handler.conn
self._response = response
self._handler = handler
self.stream_ended = False
self.content_written = False
self.request = response.request
self.logger = response.logger
def write_headers(self, headers, status_code, status_message=None, stream_id=None, last=False):
Send a HEADER frame that is tracked by the local state machine.
Write a HEADER frame using the H2 Connection object. This only works if the stream is in a state to send
HEADER frames.
:param headers: List of (header, value) tuples
:param status_code: The HTTP status code of the response
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
:param last: Flag to signal if this is the last frame in stream.
formatted_headers = []
secondary_headers = [] # Non ':' prefixed headers are to be added afterwards
for header, value in headers:
# h2_headers are native strings
# header field names are strings of ASCII
if isinstance(header, binary_type):
header = header.decode('ascii')
# value in headers can be either string or integer
if isinstance(value, binary_type):
value = self.decode(value)
if header in h2_headers:
header = ':' + header
formatted_headers.append((header, str(value)))
secondary_headers.append((header, str(value)))
formatted_headers.append((':status', str(status_code)))
with self.h2conn as connection:
stream_id=self.request.h2_stream_id if stream_id is None else stream_id,
end_stream=last or self.request.method == "HEAD"
def write_data(self, item, last=False, stream_id=None):
Send a DATA frame that is tracked by the local state machine.
Write a DATA frame using the H2 Connection object. This only works if the stream is in a state to send
DATA frames. Uses flow control to split data into multiple data frames if it exceeds the size that can
be in a single frame.
:param item: The content of the DATA frame
:param last: Flag to signal if this is the last frame in stream.
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
if isinstance(item, (text_type, binary_type)):
data = BytesIO(self.encode(item))
data = item
# Find the length of the data
data.seek(0, 2)
data_len = data.tell()
# If the data is longer than max payload size, need to write it in chunks
payload_size = self.get_max_payload_size()
while data_len > payload_size:
self.write_data_frame(data.read(payload_size), False, stream_id)
data_len -= payload_size
payload_size = self.get_max_payload_size()
self.write_data_frame(data.read(), last, stream_id)
def write_data_frame(self, data, last, stream_id=None):
with self.h2conn as connection:
stream_id=self.request.h2_stream_id if stream_id is None else stream_id,
self.stream_ended = last
def write_push(self, promise_headers, push_stream_id=None, status=None, response_headers=None, response_data=None):
"""Write a push promise, and optionally write the push content.
This will write a push promise to the request stream. If you do not provide headers and data for the response,
then no response will be pushed, and you should push them yourself using the ID returned from this function
:param promise_headers: A list of header tuples that matches what the client would use to
request the pushed response
:param push_stream_id: The ID of the stream the response should be pushed to. If none given, will
use the next available id.
:param status: The status code of the response, REQUIRED if response_headers given
:param response_headers: The headers of the response
:param response_data: The response data.
:return: The ID of the push stream
with self.h2conn as connection:
push_stream_id = push_stream_id if push_stream_id is not None else connection.get_next_available_stream_id()
connection.push_stream(self.request.h2_stream_id, push_stream_id, promise_headers)
has_data = response_data is not None
if response_headers is not None:
assert status is not None
self.write_headers(response_headers, status, stream_id=push_stream_id, last=not has_data)
if has_data:
self.write_data(response_data, last=True, stream_id=push_stream_id)
return push_stream_id
def end_stream(self, stream_id=None):
"""Ends the stream with the given ID, or the one that request was made on if no ID given."""
with self.h2conn as connection:
connection.end_stream(stream_id if stream_id is not None else self.request.h2_stream_id)
self.stream_ended = True
def write_raw_header_frame(self, headers, stream_id=None, end_stream=False, end_headers=False, frame_cls=HeadersFrame):
Ignores the state machine of the stream and sends a HEADER frame regardless.
Unlike `write_headers`, this does not check to see if a stream is in the correct state to have HEADER frames
sent through to it. It will build a HEADER frame and send it without using the H2 Connection object other than
to HPACK encode the headers.
:param headers: List of (header, value) tuples
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
:param end_stream: Set to True to add END_STREAM flag to frame
:param end_headers: Set to True to add END_HEADERS flag to frame
if not stream_id:
stream_id = self.request.h2_stream_id
header_t = []
for header, value in headers:
header_t.append(HeaderTuple(header, value))
with self.h2conn as connection:
frame = frame_cls(stream_id, data=connection.encoder.encode(header_t))
if end_stream:
self.stream_ended = True
if end_headers:
data = frame.serialize()
def write_raw_data_frame(self, data, stream_id=None, end_stream=False):
Ignores the state machine of the stream and sends a DATA frame regardless.
Unlike `write_data`, this does not check to see if a stream is in the correct state to have DATA frames
sent through to it. It will build a DATA frame and send it without using the H2 Connection object. It will
not perform any flow control checks.
:param data: The data to be sent in the frame
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
:param end_stream: Set to True to add END_STREAM flag to frame
if not stream_id:
stream_id = self.request.h2_stream_id
frame = DataFrame(stream_id, data=data)
if end_stream:
self.stream_ended = True
data = frame.serialize()
def write_raw_continuation_frame(self, headers, stream_id=None, end_headers=False):
Ignores the state machine of the stream and sends a CONTINUATION frame regardless.
This provides the ability to create and write a CONTINUATION frame to the stream, which is not exposed by
`write_headers` as the h2 library handles the split between HEADER and CONTINUATION internally. Will perform
HPACK encoding on the headers.
:param headers: List of (header, value) tuples
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
:param end_headers: Set to True to add END_HEADERS flag to frame
self.write_raw_header_frame(headers, stream_id=stream_id, end_headers=end_headers, frame_cls=ContinuationFrame)
def get_max_payload_size(self, stream_id=None):
"""Returns the maximum size of a payload for the given stream."""
stream_id = stream_id if stream_id is not None else self.request.h2_stream_id
with self.h2conn as connection:
return min(connection.remote_settings.max_frame_size, connection.local_flow_control_window(stream_id)) - 9
def write(self, connection):
self.content_written = True
data = connection.data_to_send()
def write_raw(self, raw_data):
"""Used for sending raw bytes/data through the socket"""
self.content_written = True
def decode(self, data):
"""Convert bytes to unicode according to response.encoding."""
if isinstance(data, binary_type):
return data.decode(self._response.encoding)
elif isinstance(data, text_type):
return data
raise ValueError(type(data))
def encode(self, data):
"""Convert unicode to bytes according to response.encoding."""
if isinstance(data, binary_type):
return data
elif isinstance(data, text_type):
return data.encode(self._response.encoding)
raise ValueError
class ResponseWriter(object):
"""Object providing an API to write out an HTTP response.
:param handler: The RequestHandler being used.
:param response: The Response associated with this writer."""
def __init__(self, handler, response):
self._wfile = handler.wfile
self._response = response
self._handler = handler
self._status_written = False
self._headers_seen = set()
self._headers_complete = False
self.content_written = False
self.request = response.request
self.file_chunk_size = 32 * 1024
self.default_status = 200
def _seen_header(self, name):
return self.encode(name.lower()) in self._headers_seen
def write_status(self, code, message=None):
"""Write out the status line of a response.
:param code: The integer status code of the response.
:param message: The message of the response. Defaults to the message commonly used
with the status code."""
if message is None:
if code in response_codes:
message = response_codes[code][0]
message = ''
self.write(b"%s %d %s\r\n" %
(isomorphic_encode(self._response.request.protocol_version), code, isomorphic_encode(message)))
self._status_written = True
def write_header(self, name, value):
"""Write out a single header for the response.
If a status has not been written, a default status will be written (currently 200)
:param name: Name of the header field
:param value: Value of the header field
:return: A boolean indicating whether the write succeeds
if not self._status_written:
if not self.write(name):
return False
if not self.write(b": "):
return False
if isinstance(value, int):
if not self.write(text_type(value)):
return False
elif not self.write(value):
return False
return self.write(b"\r\n")
def write_default_headers(self):
for name, f in [("Server", self._handler.version_string),
("Date", self._handler.date_time_string)]:
if not self._seen_header(name):
if not self.write_header(name, f()):
return False
if (isinstance(self._response.content, (binary_type, text_type)) and
not self._seen_header("content-length")):
#Would be nice to avoid double-encoding here
if not self.write_header("Content-Length", len(self.encode(self._response.content))):
return False
return True
def end_headers(self):
"""Finish writing headers and write the separator.
Unless add_required_headers on the response is False,
this will also add HTTP-mandated headers that have not yet been supplied
to the response headers.
:return: A boolean indicating whether the write succeeds
if self._response.add_required_headers:
if not self.write_default_headers():
return False
if not self.write("\r\n"):
return False
if not self._seen_header("content-length"):
self._response.close_connection = True
self._headers_complete = True
return True
def write_content(self, data):
"""Write the body of the response.
HTTP-mandated headers will be automatically added with status default to 200 if they have
not been explicitly set.
:return: A boolean indicating whether the write succeeds
if not self._status_written:
if not self._headers_complete:
self._response.content = data
return self.write_raw_content(data)
def write_raw_content(self, data):
"""Writes the data 'as is'"""
if data is None:
raise ValueError('data cannot be None')
if isinstance(data, (text_type, binary_type)):
# Deliberately allows both text and binary types. See `self.encode`.
return self.write(data)
return self.write_content_file(data)
def write(self, data):
"""Write directly to the response, converting unicode to bytes
according to response.encoding.
:return: A boolean indicating whether the write succeeds
self.content_written = True
return True
except socket.error:
# This can happen if the socket got closed by the remote end
return False
def write_content_file(self, data):
"""Write a file-like object directly to the response in chunks."""
self.content_written = True
success = True
while True:
buf = data.read(self.file_chunk_size)
if not buf:
success = False
except socket.error:
success = False
return success
def encode(self, data):
"""Convert unicode to bytes according to response.encoding."""
if isinstance(data, binary_type):
return data
elif isinstance(data, text_type):
return data.encode(self._response.encoding)
raise ValueError("data %r should be text or binary, but is %s" % (data, type(data)))
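The difference between `set` and `append` on `ResponseHeaders` above can be sketched with a small standalone class. `HeaderStore` is a hypothetical stand-in written for this example, not the wptserve class. It leaves out the `_maybe_encode` step, and the body of `append` is an assumption based on the docstrings: existing values are kept and the new one is added.

```python
from collections import OrderedDict


class HeaderStore(object):
    """Hypothetical stand-in for illustration; not the wptserve class."""

    def __init__(self):
        # Maps lower-cased name -> (name as supplied, list of values).
        self.data = OrderedDict()

    def set(self, key, value):
        # Overwrite any previous values stored under this name.
        self.data[key.lower()] = (key, [value])

    def append(self, key, value):
        # Assumed behaviour: keep existing values and add another one.
        if key.lower() in self.data:
            self.data[key.lower()][1].append(value)
        else:
            self.set(key, value)

    def __iter__(self):
        # Yield one (name, value) pair per stored value.
        for key, values in self.data.values():
            for value in values:
                yield key, value


headers = HeaderStore()
headers.set(b"Content-Type", b"text/plain")
headers.append(b"Set-Cookie", b"a=1")
headers.append(b"set-cookie", b"b=2")      # same header, different case
headers.set(b"content-type", b"text/html")  # replaces the earlier value
header_list = list(headers)
```

Because lookups use the lower-cased name, `Set-Cookie` and `set-cookie` land in the same entry, while `set` discards whatever was stored before.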
@ -0,0 +1,179 @@
import itertools
import re
import sys
from .logger import get_logger
from six import binary_type, text_type
any_method = object()
class RouteTokenizer(object):
def literal(self, scanner, token):
return ("literal", token)
def slash(self, scanner, token):
return ("slash", None)
def group(self, scanner, token):
return ("group", token[1:-1])
def star(self, scanner, token):
return ("star", token[1:-3])
def scan(self, input_str):
scanner = re.Scanner([(r"/", self.slash),
(r"{\w*}", self.group),
(r"\*", self.star),
(r"(?:\\.|[^{\*/])*", self.literal),])
return scanner.scan(input_str)
class RouteCompiler(object):
def __init__(self):
def reset(self):
self.star_seen = False
def compile(self, tokens):
func_map = {"slash":self.process_slash,
re_parts = ["^"]
if not tokens or tokens[0][0] != "slash":
tokens = itertools.chain([("slash", None)], tokens)
for token in tokens:
if self.star_seen:
return re.compile("".join(re_parts))
def process_literal(self, token):
return re.escape(token[1])
def process_slash(self, token):
return "/"
def process_group(self, token):
if self.star_seen:
raise ValueError("Group seen after star in regexp")
return "(?P<%s>[^/]+)" % token[1]
def process_star(self, token):
if self.star_seen:
raise ValueError("Star seen after star in regexp")
self.star_seen = True
return "(.*"
def compile_path_match(route_pattern):
"""tokens: / or literal or match or *"""
tokenizer = RouteTokenizer()
tokens, unmatched = tokenizer.scan(route_pattern)
assert unmatched == "", unmatched
compiler = RouteCompiler()
return compiler.compile(tokens)
class Router(object):
"""Object for matching handler functions to requests.
:param doc_root: Absolute path of the filesystem location from
which to serve tests
:param routes: Initial routes to add; a list of three item tuples
(method, path_pattern, handler_function), defined
as for register()
def __init__(self, doc_root, routes):
self.doc_root = doc_root
self.routes = []
self.logger = get_logger()
# Add the doc_root to the Python path, so that any Python handler can
# correctly locate helper scripts (see RFC_TO_BE_LINKED).
# TODO: In a perfect world, Router would not need to know about this
# and the handler itself would take care of it. Currently, however, we
# treat handlers like functions and so there's no easy way to do that.
if self.doc_root not in sys.path:
sys.path.insert(0, self.doc_root)
for route in reversed(routes):
def register(self, methods, path, handler):
r"""Register a handler for a set of paths.
:param methods: Set of methods this should match. "*" is a
special value indicating that all methods should
be matched.
:param path_pattern: Match pattern that will be used to determine if
a request path matches this route. Match patterns
consist of either literal text, match groups,
denoted {name}, which match any character except /,
and at most one \*, which matches any character and
creates a match group to the end of the string.
If there is no leading "/" on the pattern, this is
automatically implied. For example::
api/{resource}/*.json
Would match `/api/test/data.json` or
`/api/test/test2/data.json`, but not `/api/test/data.py`.
The match groups are made available in the request object
as a dictionary through the route_match property. For
example, given the route pattern above and the path
`/api/test/data.json`, the route_match property would contain::
{"resource": "test", "*": "data.json"}
:param handler: Function that will be called to process matching
requests. This must take two parameters, the request
object and the response object.
if isinstance(methods, (binary_type, text_type)) or methods is any_method:
methods = [methods]
for method in methods:
self.routes.append((method, compile_path_match(path), handler))
self.logger.debug("Route pattern: %s" % self.routes[-1][1].pattern)
def get_handler(self, request):
"""Get a handler for a request or None if there is no handler.
:param request: Request to get a handler for.
:rtype: Callable or None
for method, regexp, handler in reversed(self.routes):
if (request.method == method or
method in (any_method, "*") or
(request.method == "HEAD" and method == "GET")):
m = regexp.match(request.url_parts.path)
if m:
if not hasattr(handler, "__class__"):
name = handler.__name__
name = handler.__class__.__name__
self.logger.debug("Found handler %s" % name)
match_parts = m.groupdict().copy()
if len(match_parts) < len(m.groups()):
match_parts["*"] = m.groups()[-1]
request.route_match = match_parts
return handler
return None
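The matching rules described in `Router.register` can be sketched as a standalone function. This is a simplified illustration, not the vendored `RouteTokenizer`/`RouteCompiler`: `pattern_to_regex` and `route_match` are hypothetical helpers, backslash escapes are not handled, and both the example pattern and the closing `)` and `$` are inferred from the docstring's examples rather than taken from lines shown in this listing.

```python
import re


def pattern_to_regex(route_pattern):
    """Hypothetical helper: translate a route pattern into a compiled regex."""
    if not route_pattern.startswith("/"):
        route_pattern = "/" + route_pattern  # a leading slash is implied
    parts = ["^"]
    star_seen = False
    # Split into {name} groups, a star, or runs of literal text.
    for token in re.findall(r"\{\w*\}|\*|[^{*]+", route_pattern):
        if token == "*":
            if star_seen:
                raise ValueError("Star seen after star in pattern")
            star_seen = True
            parts.append("(.*")  # group stays open to the end of the string
        elif token.startswith("{"):
            if star_seen:
                raise ValueError("Group seen after star in pattern")
            parts.append("(?P<%s>[^/]+)" % token[1:-1])
        else:
            parts.append(re.escape(token))
    if star_seen:
        parts.append(")")
    parts.append("$")
    return re.compile("".join(parts))


def route_match(regexp, path):
    """Return the match-group dictionary for a path, or None if it does not match."""
    m = regexp.match(path)
    if m is None:
        return None
    match_parts = m.groupdict().copy()
    if len(match_parts) < len(m.groups()):
        match_parts["*"] = m.groups()[-1]  # the unnamed star group
    return match_parts


regexp = pattern_to_regex("api/{resource}/*.json")
print(route_match(regexp, "/api/test/data.json"))
print(route_match(regexp, "/api/test/data.py"))
```

Named groups stop at `/`, while the star group runs to the end of the path, which is why a nested path such as `/api/test/test2/data.json` still matches.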
@ -0,0 +1,6 @@
from . import handlers
from .router import any_method
routes = [(any_method, "*.py", handlers.python_script_handler),
("GET", "*.asis", handlers.as_is_handler),
("GET", "*", handlers.file_handler),
@ -0,0 +1,896 @@
from six.moves import BaseHTTPServer
import errno
import os
import socket
from six.moves.socketserver import ThreadingMixIn
import ssl
import sys
import threading
import time
import traceback
from six import binary_type, text_type
import uuid
from collections import OrderedDict
from six.moves.queue import Queue
from h2.config import H2Configuration
from h2.connection import H2Connection
from h2.events import RequestReceived, ConnectionTerminated, DataReceived, StreamReset, StreamEnded
from h2.exceptions import StreamClosedError, ProtocolError
from h2.settings import SettingCodes
from h2.utilities import extract_method_header
from six.moves.urllib.parse import urlsplit, urlunsplit
from mod_pywebsocket import dispatch
from mod_pywebsocket.handshake import HandshakeException
from . import routes as default_routes
from .config import ConfigBuilder
from .logger import get_logger
from .request import Server, Request, H2Request
from .response import Response, H2Response
from .router import Router
from .utils import HTTPException, isomorphic_decode, isomorphic_encode
from .constants import h2_headers
from .ws_h2_handshake import WsH2Handshaker
# We need to stress test that browsers can send/receive many headers (there is
# no specified limit), but the Python stdlib has an arbitrary limit of 100
# headers. Hitting the limit would produce an exception that is silently caught
# in Python 2 but leads to HTTP 431 in Python 3, so we monkey patch it higher.
# https://bugs.python.org/issue26586
# https://github.com/web-platform-tests/wpt/pull/24451
from six.moves import http_client
assert isinstance(getattr(http_client, '_MAXHEADERS'), int)
setattr(http_client, '_MAXHEADERS', 512)
HTTP server designed for testing purposes.
The server is designed to provide flexibility in the way that
requests are handled, and to provide control both of exactly
what bytes are put on the wire for the response, and in the
timing of sending those bytes.
The server is based on the stdlib HTTPServer, but with some
notable differences in the way that requests are processed.
Overall processing is handled by a WebTestRequestHandler,
which is a subclass of BaseHTTPRequestHandler. This is responsible
for parsing the incoming request. A RequestRewriter is then
applied and may change the request data if it matches a
supplied rule.
Once the request data has been finalised, Request and Response
objects are constructed. These are used by the other parts of the
system to read information about the request and manipulate the
Each request is handled by a particular handler function. The
mapping between Request and the appropriate handler is determined
by a Router. By default handlers are installed to interpret files
under the document root with .py extensions as executable python
files (see handlers.py for the api for such files), .asis files as
bytestreams to be sent literally and all other files to be served
The handler functions are responsible for either populating the
fields of the response object, which will then be written when the
handler returns, or for directly writing to the output stream.
class RequestRewriter(object):
def __init__(self, rules):
"""Object for rewriting the request path.
:param rules: Initial rules to add; a list of three item tuples
(method, input_path, output_path), defined as for
self.rules = {}
for rule in reversed(rules):
self.logger = get_logger()
def register(self, methods, input_path, output_path):
"""Register a rewrite rule.
:param methods: Set of methods this should match. "*" is a
special value indicating that all methods should
be matched.
:param input_path: Path to match for the initial request.
:param output_path: Path to replace the input path with in
the request.
if isinstance(methods, (binary_type, text_type)):
methods = [methods]
self.rules[input_path] = (methods, output_path)
def rewrite(self, request_handler):
"""Rewrite the path in a BaseHTTPRequestHandler instance, if
it matches a rule.
:param request_handler: BaseHTTPRequestHandler for which to
rewrite the request.
split_url = urlsplit(request_handler.path)
if split_url.path in self.rules:
methods, destination = self.rules[split_url.path]
if "*" in methods or request_handler.command in methods:
self.logger.debug("Rewriting request path %s to %s" %
(request_handler.path, destination))
new_url = list(split_url)
new_url[2] = destination
new_url = urlunsplit(new_url)
request_handler.path = new_url
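The rewrite step only replaces the path component and leaves the query string and fragment alone. The sketch below shows that behaviour in isolation. It assumes Python 3's `urllib.parse` in place of `six.moves`, and `rewrite_path` plus the sample rule are hypothetical names used only for illustration.

```python
from urllib.parse import urlsplit, urlunsplit


def rewrite_path(url, rules, method):
    """Hypothetical standalone version of the rewrite step.

    rules maps an input path to a (methods, output_path) tuple.
    """
    split_url = urlsplit(url)
    if split_url.path in rules:
        methods, destination = rules[split_url.path]
        if "*" in methods or method in methods:
            new_url = list(split_url)
            new_url[2] = destination  # index 2 is the path component
            return urlunsplit(new_url)
    return url


rules = {"/old": (["GET"], "/new/location")}
print(rewrite_path("/old?x=1#frag", rules, "GET"))   # path replaced, query kept
print(rewrite_path("/old?x=1", rules, "POST"))       # method not covered by the rule
```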
class WebTestServer(ThreadingMixIn, BaseHTTPServer.HTTPServer):
allow_reuse_address = True
acceptable_errors = (errno.EPIPE, errno.ECONNABORTED)
request_queue_size = 2000
# Ensure that we don't hang on shutdown waiting for requests
daemon_threads = True
def __init__(self, server_address, request_handler_cls,
router, rewriter, bind_address, ws_doc_root=None,
config=None, use_ssl=False, key_file=None, certificate=None,
encrypt_after_connect=False, latency=None, http2=False, **kwargs):
"""Server for HTTP(s) Requests
:param server_address: tuple of (server_name, port)
:param request_handler_cls: BaseHTTPRequestHandler-like class to use for
handling requests.
:param router: Router instance to use for matching requests to handler
:param rewriter: RequestRewriter-like instance to use for preprocessing
requests before they are routed
:param config: Dictionary holding environment configuration settings for
handlers to read, or None to use the default values.
:param use_ssl: Boolean indicating whether the server should use SSL
:param key_file: Path to key file to use if SSL is enabled.
:param certificate: Path to certificate to use if SSL is enabled.
:param ws_doc_root: Document root for websockets
:param encrypt_after_connect: For each connection, don't start encryption
until a CONNECT message has been received.
This enables the server to act as a
:param bind_address: True to bind the server to both the IP address and
port specified in the server_address parameter.
False to bind the server only to the port in the
server_address parameter, but not to the address.
:param latency: Delay in ms to wait before serving each response, or
callable that returns a delay in ms
self.router = router
self.rewriter = rewriter
self.scheme = "http2" if http2 else "https" if use_ssl else "http"
self.logger = get_logger()
self.latency = latency
if bind_address:
hostname_port = server_address
hostname_port = ("",server_address[1])
# super() doesn't work here because BaseHTTPServer.HTTPServer is an old-style class
BaseHTTPServer.HTTPServer.__init__(self, hostname_port, request_handler_cls, **kwargs)
if config is not None:
Server.config = config
self.logger.debug("Using default configuration")
with ConfigBuilder(browser_host=server_address[0],
ports={"http": [self.server_address[1]]}) as config:
assert config["ssl_config"] is None
Server.config = config
self.ws_doc_root = ws_doc_root
self.key_file = key_file
self.certificate = certificate
self.encrypt_after_connect = use_ssl and encrypt_after_connect
if use_ssl and not encrypt_after_connect:
if http2:
ssl_context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(keyfile=self.key_file, certfile=self.certificate)
self.socket = ssl_context.wrap_socket(self.socket,
self.socket = ssl.wrap_socket(self.socket,
def handle_error(self, request, client_address):
error = sys.exc_info()[1]
if ((isinstance(error, socket.error) and
isinstance(error.args, tuple) and
error.args[0] in self.acceptable_errors) or
(isinstance(error, IOError) and
error.errno in self.acceptable_errors)):
pass  # remote hung up before the result was sent
class BaseWebTestRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""RequestHandler for WebTestHttpd"""
def __init__(self, *args, **kwargs):
self.logger = get_logger()
BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args, **kwargs)
def finish_handling_h1(self, request_line_is_valid):
request = Request(self)
response = Response(self, request)
if request.method == "CONNECT":
if not request_line_is_valid:
self.logger.debug("%s %s" % (request.method, request.request_path))
handler = self.server.router.get_handler(request)
self.finish_handling(request, response, handler)
def finish_handling(self, request, response, handler):
# If the handler we used for the request had a non-default base path
# set, update the doc_root of the request to reflect this
if hasattr(handler, "base_path") and handler.base_path:
request.doc_root = handler.base_path
if hasattr(handler, "url_base") and handler.url_base != "/":
request.url_base = handler.url_base
if self.server.latency is not None:
if callable(self.server.latency):
latency = self.server.latency()
latency = self.server.latency
self.logger.warning("Latency enabled. Sleeping %i ms" % latency)
time.sleep(latency / 1000.)
if handler is None:
self.logger.debug("No Handler found!")
handler(request, response)
except HTTPException as e:
response.set_error(e.code, e.message)
except Exception as e:
self.respond_with_error(response, e)
self.logger.debug("%i %s %s (%s) %i" % (response.status[0],
if not response.writer.content_written:
# Older Python handlers won't send an END_STREAM data frame, so this keeps
# backwards compatibility with handlers that don't close their streams
if isinstance(response, H2Response) and not response.writer.stream_ended:
# If we want to remove this in the future, a solution is needed for
# scripts that produce a non-string iterable of content, since these
# can't set a Content-Length header. A notable example of this kind of
# problem is with the trickle pipe i.e. foo.js?pipe=trickle(d1)
if response.close_connection:
self.close_connection = True
if not self.close_connection:
# Ensure that the whole request has been read from the socket
def handle_connect(self, response):
self.logger.debug("Got CONNECT")
response.status = 200
if self.server.encrypt_after_connect:
self.logger.debug("Enabling SSL for connection")
self.request = ssl.wrap_socket(self.connection,
def respond_with_error(self, response, e):
message = str(e)
if message:
err = [message]
err = []
response.set_error(500, "\n".join(err))
class Http2WebTestRequestHandler(BaseWebTestRequestHandler):
protocol_version = "HTTP/2.0"
def handle_one_request(self):
This is the main HTTP/2.0 Handler.
When a browser opens a connection to the server
on the HTTP/2.0 port, the server enters this method, which initiates the h2 connection
and keeps running for the duration of the interaction, reading from and writing to
the socket directly.
Because there can be multiple H2 connections active at the same
time, a UUID is created for each so that it is easier to tell them apart in the logs.
config = H2Configuration(client_side=False)
self.conn = H2ConnectionGuard(H2Connection(config=config))
self.close_connection = False
# Generate a UUID to make it easier to distinguish different H2 connection debug messages
self.uid = str(uuid.uuid4())[:8]
self.logger.debug('(%s) Initiating h2 Connection' % self.uid)
with self.conn as connection:
# Bootstrapping WebSockets with HTTP/2 specification requires
# ENABLE_CONNECT_PROTOCOL to be set in order to enable WebSocket
# over HTTP/2
new_settings = dict(connection.local_settings)
new_settings[SettingCodes.ENABLE_CONNECT_PROTOCOL] = 1
data = connection.data_to_send()
window_size = connection.remote_settings.initial_window_size
# Dict of { stream_id: (thread, queue) }
stream_queues = {}
while not self.close_connection:
data = self.request.recv(window_size)
if not data:
self.logger.debug('(%s) Socket Closed' % self.uid)
self.close_connection = True
with self.conn as connection:
frames = connection.receive_data(data)
window_size = connection.remote_settings.initial_window_size
self.logger.debug('(%s) Frames Received: ' % self.uid + str(frames))
for frame in frames:
if isinstance(frame, ConnectionTerminated):
self.logger.debug('(%s) Connection terminated by remote peer ' % self.uid)
self.close_connection = True
# Flood all the streams with the connection-terminated frame; this will cause them to stop
for stream_id, (thread, queue) in stream_queues.items():
elif hasattr(frame, 'stream_id'):
if frame.stream_id not in stream_queues:
queue = Queue()
stream_queues[frame.stream_id] = (self.start_stream_thread(frame, queue), queue)
if isinstance(frame, StreamEnded) or (hasattr(frame, "stream_ended") and frame.stream_ended):
del stream_queues[frame.stream_id]
except (socket.timeout, socket.error) as e:
self.logger.error('(%s) Closing Connection - \n%s' % (self.uid, str(e)))
if not self.close_connection:
self.close_connection = True
for stream_id, (thread, queue) in stream_queues.items():
except Exception as e:
self.logger.error('(%s) Unexpected Error - \n%s' % (self.uid, str(e)))
for stream_id, (thread, queue) in stream_queues.items():
def _is_extended_connect_frame(self, frame):
if not isinstance(frame, RequestReceived):
return False
method = extract_method_header(frame.headers)
if method != b"CONNECT":
return False
protocol = ""
for key, value in frame.headers:
if key in (b':protocol', u':protocol'):
protocol = isomorphic_encode(value)
if protocol != b"websocket":
raise ProtocolError("Invalid protocol %s with CONNECT METHOD" % (protocol,))
return True
def start_stream_thread(self, frame, queue):
This starts a new thread to handle frames for a specific stream.
:param frame: The first frame on the stream
:param queue: A queue object that the thread will use to check for new frames
:return: The thread object that has already been started
if self._is_extended_connect_frame(frame):
target = Http2WebTestRequestHandler._stream_ws_thread
target = Http2WebTestRequestHandler._stream_thread
t = threading.Thread(
args=(self, frame.stream_id, queue)
return t
def _stream_ws_thread(self, stream_id, queue):
frame = queue.get(True, None)
rfile, wfile = os.pipe()
rfile, wfile = os.fdopen(rfile, 'rb'), os.fdopen(wfile, 'wb', 0)  # needs to be unbuffered for websockets
stream_handler = H2HandlerCopy(self, frame, rfile)
h2request = H2Request(stream_handler)
h2response = H2Response(stream_handler, h2request)
dispatcher = dispatch.Dispatcher(self.server.ws_doc_root, None, False)
if not dispatcher.get_handler_suite(stream_handler.path):
request_wrapper = _WebSocketRequest(stream_handler, h2response)
handshaker = WsH2Handshaker(request_wrapper, dispatcher)
except HandshakeException as e:
self.logger.info('Handshake failed for error: %s', e)
# h2 Handshaker prepares the headers but does not send them down the
# wire. Flush the headers here.
request_wrapper._dispatcher = dispatcher
# we need two threads:
# - one to handle the frame queue
# - one to handle the request (dispatcher.transfer_data is blocking)
# the alternative is to have only one (blocking) thread. That thread
# will call transfer_data. That would require a special case in
# handle_one_request, to bypass the queue and write data to wfile
# directly.
t = threading.Thread(
args=(self, request_wrapper, stream_handler, queue)
while not self.close_connection:
frame = queue.get(True, None)
if isinstance(frame, DataReceived):
if frame.stream_ended:
raise NotImplementedError("frame.stream_ended")
elif frame is None or isinstance(frame, (StreamReset, StreamEnded, ConnectionTerminated)):
self.logger.debug('(%s - %s) Stream Reset, Thread Closing' % (self.uid, stream_id))
def _stream_ws_sub_thread(self, request, stream_handler, queue):
dispatcher = request._dispatcher
stream_id = stream_handler.h2_stream_id
with stream_handler.conn as connection:
data = connection.data_to_send()
except StreamClosedError: # maybe the stream has already been closed
def _stream_thread(self, stream_id, queue):
This thread processes frames for a specific stream. It waits for frames to be placed
in the queue, and processes them. When it receives a request frame, it will start processing
immediately, even if there are data frames to follow. One of the reasons for this is that it
can detect invalid requests before needing to read the rest of the frames.
# The file-like pipe object used to pass data to the request object if data is received
wfile = None
request = None
response = None
req_handler = None
while not self.close_connection:
# Wait for next frame, blocking
frame = queue.get(True, None)
self.logger.debug('(%s - %s) %s' % (self.uid, stream_id, str(frame)))
if isinstance(frame, RequestReceived):
rfile, wfile = os.pipe()
rfile, wfile = os.fdopen(rfile, 'rb'), os.fdopen(wfile, 'wb')
stream_handler = H2HandlerCopy(self, frame, rfile)
request = H2Request(stream_handler)
response = H2Response(stream_handler, request)
req_handler = stream_handler.server.router.get_handler(request)
if hasattr(req_handler, "frame_handler"):
# Convert this to a handler that will utilise H2 specific functionality, such as handling individual frames
req_handler = self.frame_handler(request, response, req_handler)
if hasattr(req_handler, 'handle_headers'):
req_handler.handle_headers(frame, request, response)
elif isinstance(frame, DataReceived):
if hasattr(req_handler, 'handle_data'):
req_handler.handle_data(frame, request, response)
if frame.stream_ended:
elif frame is None or isinstance(frame, (StreamReset, StreamEnded, ConnectionTerminated)):
self.logger.debug('(%s - %s) Stream Reset, Thread Closing' % (self.uid, stream_id))
if request is not None:
if hasattr(frame, "stream_ended") and frame.stream_ended:
self.finish_handling(request, response, req_handler)
def frame_handler(self, request, response, handler):
return handler.frame_handler(request)
except HTTPException as e:
response.set_error(e.code, e.message)
except Exception as e:
self.respond_with_error(response, e)
class H2ConnectionGuard(object):
"""H2Connection objects are not thread-safe, so this guard serialises access to one with a lock"""
lock = threading.Lock()
def __init__(self, obj):
assert isinstance(obj, H2Connection)
self.obj = obj
def __enter__(self):
return self.obj
def __exit__(self, exception_type, exception_value, traceback):
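The guard pattern used by `H2ConnectionGuard` can be sketched with a plain dictionary in place of an `H2Connection`. This is a hedged illustration only: the `Guard` class and the counter are hypothetical, and the `acquire`/`release` calls are an assumption about the method bodies, which are not shown in this listing.

```python
import threading


class Guard(object):
    """Hypothetical generic guard: hands out the wrapped object only while holding a lock."""
    lock = threading.Lock()  # class-level, as in the listing: shared by all instances

    def __init__(self, obj):
        self.obj = obj

    def __enter__(self):
        self.lock.acquire()  # assumed: block until no other thread holds the object
        return self.obj

    def __exit__(self, exception_type, exception_value, traceback):
        self.lock.release()  # assumed: always release, even if the body raised


counter = Guard({"value": 0})


def bump():
    for _ in range(1000):
        with counter as state:
            state["value"] += 1


threads = [threading.Thread(target=bump) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter.obj["value"])
```

Because the lock is a class attribute, every guard instance contends for the same lock. That is simple, but it also serialises work across unrelated connections.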
class H2Headers(dict):
def __init__(self, headers):
self.raw_headers = OrderedDict()
for key, val in headers:
key = isomorphic_decode(key)
val = isomorphic_decode(val)
self.raw_headers[key] = val
dict.__setitem__(self, self._convert_h2_header_to_h1(key), val)
def _convert_h2_header_to_h1(self, header_key):
if header_key[1:] in h2_headers and header_key[0] == ':':
return header_key[1:]
return header_key
# TODO This does not seem relevant for H2 headers, so using a dummy function for now
def getallmatchingheaders(self, header):
return ['dummy function']
class H2HandlerCopy(object):
def __init__(self, handler, req_frame, rfile):
self.headers = H2Headers(req_frame.headers)
self.command = self.headers['method']
self.path = self.headers['path']
self.h2_stream_id = req_frame.stream_id
self.server = handler.server
self.protocol_version = handler.protocol_version
self.client_address = handler.client_address
self.raw_requestline = ''
self.rfile = rfile
self.request = handler.request
self.conn = handler.conn
class Http1WebTestRequestHandler(BaseWebTestRequestHandler):
protocol_version = "HTTP/1.1"
def handle_one_request(self):
response = None
self.close_connection = False
request_line_is_valid = self.get_request_line()
if self.close_connection:
request_is_valid = self.parse_request()
if not request_is_valid:
# parse_request() actually sends its own error responses
except socket.timeout as e:
self.log_error("Request timed out: %r", e)
self.close_connection = True
except Exception:
err = traceback.format_exc()
if response:
response.set_error(500, err)
def get_request_line(self):
self.raw_requestline = self.rfile.readline(65537)
except socket.error:
self.close_connection = True
return False
if len(self.raw_requestline) > 65536:
self.requestline = ''
self.request_version = ''
self.command = ''
return False
if not self.raw_requestline:
self.close_connection = True
return True
class WebTestHttpd(object):
:param host: Host from which to serve (default:
:param port: Port from which to serve (default: 8000)
:param server_cls: Class to use for the server (default depends on ssl vs non-ssl)
:param handler_cls: Class to use for the RequestHandler
:param use_ssl: Use an SSL server if no explicit server_cls is supplied
:param key_file: Path to key file to use if ssl is enabled
:param certificate: Path to certificate file to use if ssl is enabled
:param encrypt_after_connect: For each connection, don't start encryption
until a CONNECT message has been received.
This enables the server to act as a
:param router_cls: Router class to use when matching URLs to handlers
:param doc_root: Document root for serving files
:param ws_doc_root: Document root for websockets
:param routes: List of routes with which to initialize the router
:param rewriter_cls: Class to use for request rewriter
:param rewrites: List of rewrites with which to initialize the rewriter_cls
:param config: Dictionary holding environment configuration settings for
handlers to read, or None to use the default values.
:param bind_address: Boolean indicating whether to bind server to IP address.
:param latency: Delay in ms to wait before serving each response, or
callable that returns a delay in ms
HTTP server designed for testing scenarios.
Takes a router class which provides one method get_handler which takes a Request
and returns a handler function.
.. attribute:: host
The host name or IP address of the server
.. attribute:: port
The port on which the server is running
.. attribute:: router
The Router object used to associate requests with resources for this server
.. attribute:: rewriter
The Rewriter object used for URL rewriting
.. attribute:: use_ssl
Boolean indicating whether the server is using ssl
.. attribute:: started
Boolean indicating whether the server is running
def __init__(self, host="", port=8000,
server_cls=None, handler_cls=Http1WebTestRequestHandler,
use_ssl=False, key_file=None, certificate=None, encrypt_after_connect=False,
router_cls=Router, doc_root=os.curdir, ws_doc_root=None, routes=None,
rewriter_cls=RequestRewriter, bind_address=True, rewrites=None,
latency=None, config=None, http2=False):
if routes is None:
routes = default_routes.routes
self.host = host
self.router = router_cls(doc_root, routes)
self.rewriter = rewriter_cls(rewrites if rewrites is not None else [])
self.use_ssl = use_ssl
self.http2 = http2
self.logger = get_logger()
if server_cls is None:
server_cls = WebTestServer
if use_ssl:
if not os.path.exists(key_file):
raise ValueError("SSL key not found: {}".format(key_file))
if not os.path.exists(certificate):
raise ValueError("SSL certificate not found: {}".format(certificate))
self.httpd = server_cls((host, port),
self.started = False
_host, self.port = self.httpd.socket.getsockname()
except Exception:
self.logger.critical("Failed to start HTTP server on port %s; "
"is something already using that port?" % port)
def start(self, block=False):
"""Start the server.
:param block: True to run the server on the current thread, blocking,
False to run on a separate thread."""
http_type = "http2" if self.http2 else "https" if self.use_ssl else "http"
self.logger.info("Starting %s server on %s:%s" % (http_type, self.host, self.port))
self.started = True
if block:
self.server_thread = threading.Thread(target=self.httpd.serve_forever)
self.server_thread.setDaemon(True) # don't hang on exit
def stop(self):
Stops the server.
If the server is not running, this method has no effect.
if self.started:
self.server_thread = None
self.logger.info("Stopped http server on %s:%s" % (self.host, self.port))
except AttributeError:
self.started = False
self.httpd = None
def get_url(self, path="/", query=None, fragment=None):
if not self.started:
return None
return urlunsplit(("http" if not self.use_ssl else "https",
"%s:%s" % (self.host, self.port),
path, query, fragment))
class _WebSocketConnection(object):
def __init__(self, request_handler, response):
"""Mimic mod_python mp_conn.
:param request_handler: An H2HandlerCopy instance.
:param response: An H2Response instance.
self._request_handler = request_handler
self._response = response
self.remote_addr = self._request_handler.client_address
def write(self, data):
self._response.writer.write_data(data, False)
def read(self, length):
return self._request_handler.rfile.read(length)
class _WebSocketRequest(object):
def __init__(self, request_handler, response):
"""Mimic mod_python request.
:param request_handler: An H2HandlerCopy instance.
:param response: An H2Response instance.
self.connection = _WebSocketConnection(request_handler, response)
self._response = response
self.uri = request_handler.path
self.unparsed_uri = request_handler.path
self.method = request_handler.command
# read headers from request_handler
self.headers_in = request_handler.headers
# write headers directly into H2Response
self.headers_out = response.headers
# proxies status to H2Response
def status(self):
return self._response.status
def status(self, status):
self._response.status = status
@ -0,0 +1,14 @@
from .base import NoSSLEnvironment
from .openssl import OpenSSLEnvironment
from .pregenerated import PregeneratedSSLEnvironment
environments = {"none": NoSSLEnvironment,
"openssl": OpenSSLEnvironment,
"pregenerated": PregeneratedSSLEnvironment}
def get_cls(name):
return environments[name]
except KeyError:
raise ValueError("%s is not a valid SSL type." % name)
@ -0,0 +1,17 @@
class NoSSLEnvironment(object):
ssl_enabled = False
def __init__(self, *args, **kwargs):
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
def host_cert_path(self, hosts):
return None, None
def ca_cert_path(self, hosts):
return None
@ -0,0 +1,439 @@
import functools
import os
import random
import shutil
import subprocess
import tempfile
from datetime import datetime, timedelta
from six import iteritems, PY2
# Amount of time beyond the present to consider certificates "expired." This
# allows certificates to be proactively re-generated in the "buffer" period
# prior to their exact expiration time.
CERT_EXPIRY_BUFFER = dict(hours=6)
def _ensure_str(s, encoding):
"""makes sure s is an instance of str, converting with encoding if needed"""
if isinstance(s, str):
return s
if PY2:
return s.encode(encoding)
return s.decode(encoding)
class OpenSSL(object):
def __init__(self, logger, binary, base_path, conf_path, hosts, duration,
"""Context manager for interacting with OpenSSL.
Creates a config file for the duration of the context.
:param logger: stdlib logger or python structured logger
:param binary: path to openssl binary
:param base_path: path to directory for storing certificates
:param conf_path: path for configuration file storing configuration data
:param hosts: list of hosts to include in configuration (or None if not
generating host certificates)
:param duration: Certificate duration in days"""
self.base_path = base_path
self.binary = binary
self.conf_path = conf_path
self.base_conf_path = base_conf_path
self.logger = logger
self.proc = None
self.cmd = []
self.hosts = hosts
self.duration = duration
def __enter__(self):
with open(self.conf_path, "w") as f:
f.write(get_config(self.base_path, self.hosts, self.duration))
return self
def __exit__(self, *args, **kwargs):
def log(self, line):
if hasattr(self.logger, "process_output"):
self.logger.process_output(self.proc.pid if self.proc is not None else None,
line.decode("utf8", "replace"),
command=" ".join(self.cmd))
def __call__(self, cmd, *args, **kwargs):
"""Run a command using OpenSSL in the current context.
:param cmd: The openssl subcommand to run
:param *args: Additional arguments to pass to the command
self.cmd = [self.binary, cmd]
if cmd != "x509":
self.cmd += ["-config", self.conf_path]
self.cmd += list(args)
# Copy the environment, converting to plain strings. Win32 StartProcess
# is picky about all the keys/values being str (on both Py2/3).
env = {}
for k, v in iteritems(os.environ):
env[_ensure_str(k, "utf8")] = _ensure_str(v, "utf8")
if self.base_conf_path is not None:
env["OPENSSL_CONF"] = _ensure_str(self.base_conf_path, "utf-8")
self.proc = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
stdout, stderr = self.proc.communicate()
if self.proc.returncode != 0:
raise subprocess.CalledProcessError(self.proc.returncode, self.cmd,
self.cmd = []
self.proc = None
return stdout
def make_subject(common_name,
args = [("country", "C"),
("state", "ST"),
("locality", "L"),
("organization", "O"),
("organization_unit", "OU"),
("common_name", "CN")]
rv = []
for var, key in args:
value = locals()[var]
if value is not None:
rv.append("/%s=%s" % (key, value.replace("/", "\\/")))
return "".join(rv)
def make_alt_names(hosts):
return ",".join("DNS:%s" % host for host in hosts)
def make_name_constraints(hosts):
return ",".join("permitted;DNS:%s" % host for host in hosts)
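The three string helpers above can be sketched in isolation as follows. This is a minimal illustration, not the vendored code: the names `subject_string`, `alt_names` and `name_constraints` are hypothetical, and the keyword parameters are assumed from the field table in `make_subject`, whose full signature is cut off in this view.

```python
def subject_string(common_name, country=None, state=None, locality=None,
                   organization=None, organization_unit=None):
    """Build an OpenSSL -subj style string such as "/O=Example/CN=host"."""
    fields = [("C", country), ("ST", state), ("L", locality),
              ("O", organization), ("OU", organization_unit),
              ("CN", common_name)]
    parts = []
    for key, value in fields:
        if value is not None:
            # "/" separates fields in the subject, so escape literal slashes.
            parts.append("/%s=%s" % (key, value.replace("/", "\\/")))
    return "".join(parts)


def alt_names(hosts):
    """Format hosts as a subjectAltName value."""
    return ",".join("DNS:%s" % host for host in hosts)


def name_constraints(hosts):
    """Format hosts as a nameConstraints value."""
    return ",".join("permitted;DNS:%s" % host for host in hosts)
```

Passing the fields explicitly avoids the `locals()` lookup, at the cost of repeating the parameter names in the field table.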
def get_config(root_dir, hosts, duration=30):
if hosts is None:
san_line = ""
constraints_line = ""
san_line = "subjectAltName = %s" % make_alt_names(hosts)
constraints_line = "nameConstraints = " + make_name_constraints(hosts)
if os.path.sep == "\\":
# This seems to be needed for the Shining Light OpenSSL on
# Windows, at least.
root_dir = root_dir.replace("\\", "\\\\")
rv = """[ ca ]
default_ca = CA_default
[ CA_default ]
dir = %(root_dir)s
certs = $dir
new_certs_dir = $certs
crl_dir = $dir%(sep)scrl
database = $dir%(sep)sindex.txt
private_key = $dir%(sep)scacert.key
certificate = $dir%(sep)scacert.pem
serial = $dir%(sep)sserial
crldir = $dir%(sep)scrl
crlnumber = $dir%(sep)scrlnumber
crl = $crldir%(sep)scrl.pem
RANDFILE = $dir%(sep)sprivate%(sep)s.rand
x509_extensions = usr_cert
name_opt = ca_default
cert_opt = ca_default
default_days = %(duration)d
default_crl_days = %(duration)d
default_md = sha256
preserve = no
policy = policy_anything
copy_extensions = copy
[ policy_anything ]
countryName = optional
stateOrProvinceName = optional
localityName = optional
organizationName = optional
organizationalUnitName = optional
commonName = supplied
emailAddress = optional
[ req ]
default_bits = 2048
default_keyfile = privkey.pem
distinguished_name = req_distinguished_name
attributes = req_attributes
x509_extensions = v3_ca
# Passwords for private keys if not present they will be prompted for
# input_password = secret
# output_password = secret
string_mask = utf8only
req_extensions = v3_req
[ req_distinguished_name ]
countryName = Country Name (2 letter code)
countryName_default = AU
countryName_min = 2
countryName_max = 2
stateOrProvinceName = State or Province Name (full name)
stateOrProvinceName_default =
localityName = Locality Name (eg, city)
0.organizationName = Organization Name
0.organizationName_default = Web Platform Tests
organizationalUnitName = Organizational Unit Name (eg, section)
#organizationalUnitName_default =
commonName = Common Name (e.g. server FQDN or YOUR name)
commonName_max = 64
emailAddress = Email Address
emailAddress_max = 64
[ req_attributes ]
[ usr_cert ]
[ v3_req ]
basicConstraints = CA:FALSE
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
extendedKeyUsage = serverAuth
[ v3_ca ]
basicConstraints = CA:true
keyUsage = keyCertSign
""" % {"root_dir": root_dir,
"san_line": san_line,
"duration": duration,
"constraints_line": constraints_line,
"sep": os.path.sep.replace("\\", "\\\\")}
return rv
class OpenSSLEnvironment(object):
ssl_enabled = True
def __init__(self, logger, openssl_binary="openssl", base_path=None,
password="web-platform-tests", force_regenerate=False,
duration=30, base_conf_path=None):
"""SSL environment that creates a local CA and host certificate using OpenSSL.
By default this will look in base_path for existing certificates that are still
valid and only create new certificates if there aren't any. This behaviour can
be adjusted using the force_regenerate option.
:param logger: a stdlib logging compatible logger or mozlog structured logger
:param openssl_binary: Path to the OpenSSL binary
:param base_path: Path in which certificates will be stored. If None, a temporary
directory will be used and removed when the server shuts down
:param password: Password to use
:param force_regenerate: Always create a new certificate even if one already exists.
self.logger = logger
self.temporary = False
if base_path is None:
base_path = tempfile.mkdtemp()
self.temporary = True
self.base_path = os.path.abspath(base_path)
self.password = password
self.force_regenerate = force_regenerate
self.duration = duration
self.base_conf_path = base_conf_path
self.path = None
self.binary = openssl_binary
self.openssl = None
self._ca_cert_path = None
self._ca_key_path = None
self.host_certificates = {}
def __enter__(self):
if not os.path.exists(self.base_path):
path = functools.partial(os.path.join, self.base_path)
with open(path("index.txt"), "w"):
with open(path("serial"), "w") as f:
serial = "%x" % random.randint(0, 1000000)
if len(serial) % 2:
serial = "0" + serial
self.path = path
return self
def __exit__(self, *args, **kwargs):
if self.temporary:
def _config_openssl(self, hosts):
conf_path = self.path("openssl.cfg")
return OpenSSL(self.logger, self.binary, self.base_path, conf_path, hosts,
self.duration, self.base_conf_path)
def ca_cert_path(self, hosts):
"""Get the path to the CA certificate file, generating a
new one if needed"""
if self._ca_cert_path is None and not self.force_regenerate:
if self._ca_cert_path is None:
return self._ca_cert_path
def _load_ca_cert(self):
key_path = self.path("cacert.key")
cert_path = self.path("cacert.pem")
if self.check_key_cert(key_path, cert_path, None):
self.logger.info("Using existing CA cert")
self._ca_key_path, self._ca_cert_path = key_path, cert_path
def check_key_cert(self, key_path, cert_path, hosts):
"""Check that a key and cert file exist and are valid"""
if not os.path.exists(key_path) or not os.path.exists(cert_path):
return False
with self._config_openssl(hosts) as openssl:
end_date_str = openssl("x509",
"-in", cert_path).split("=", 1)[1].strip()
# Not sure if this works in other locales
end_date = datetime.strptime(end_date_str, "%b %d %H:%M:%S %Y %Z")
time_buffer = timedelta(**CERT_EXPIRY_BUFFER)
# Because `strptime` does not account for time zone offsets, it is
# always in terms of UTC, so the current time should be calculated
# accordingly.
if end_date < datetime.utcnow() + time_buffer:
return False
#TODO: check the key actually signed the cert.
return True
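The expiry check above can be sketched on its own, under some assumptions. The value of `EXPIRY_BUFFER` is assumed (the real `CERT_EXPIRY_BUFFER` is defined outside this excerpt), the helper name `is_still_valid` is hypothetical, and the input is assumed to look like `notAfter=Mar 15 12:00:00 2031 GMT`, which is what the `split("=", 1)` in the original suggests.

```python
from datetime import datetime, timedelta

# Assumed value for illustration; the real CERT_EXPIRY_BUFFER is not shown here.
EXPIRY_BUFFER = {"hours": 6}


def is_still_valid(enddate_output, now):
    """Return True if the certificate outlives `now` plus the buffer.

    `now` is a naive datetime in UTC, passed in so the check is testable.
    """
    end_date_str = enddate_output.split("=", 1)[1].strip()
    # Month names are parsed using the current locale, as in the original.
    end_date = datetime.strptime(end_date_str, "%b %d %H:%M:%S %Y %Z")
    return end_date >= now + timedelta(**EXPIRY_BUFFER)
```

Taking `now` as a parameter, rather than calling `datetime.utcnow()` inside the function, is a design choice of this sketch only.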
def _generate_ca(self, hosts):
path = self.path
self.logger.info("Generating new CA in %s" % self.base_path)
key_path = path("cacert.key")
req_path = path("careq.pem")
cert_path = path("cacert.pem")
with self._config_openssl(hosts) as openssl:
"-newkey", "rsa:2048",
"-keyout", key_path,
"-out", req_path,
"-subj", make_subject("web-platform-tests"),
"-passout", "pass:%s" % self.password)
"-keyfile", key_path,
"-passin", "pass:%s" % self.password,
"-extensions", "v3_ca",
"-in", req_path,
"-out", cert_path)
self._ca_key_path, self._ca_cert_path = key_path, cert_path
def host_cert_path(self, hosts):
"""Get a tuple of (private key path, certificate path) for a host,
generating new ones if necessary.
hosts must be a list of all hosts to appear on the certificate, with
the primary hostname first."""
hosts = tuple(sorted(hosts, key=lambda x:len(x)))
if hosts not in self.host_certificates:
if not self.force_regenerate:
key_cert = self._load_host_cert(hosts)
key_cert = None
if key_cert is None:
key, cert = self._generate_host_cert(hosts)
key, cert = key_cert
self.host_certificates[hosts] = key, cert
return self.host_certificates[hosts]
def _load_host_cert(self, hosts):
host = hosts[0]
key_path = self.path("%s.key" % host)
cert_path = self.path("%s.pem" % host)
# TODO: check that this cert was signed by the CA cert
if self.check_key_cert(key_path, cert_path, hosts):
self.logger.info("Using existing host cert")
return key_path, cert_path
def _generate_host_cert(self, hosts):
host = hosts[0]
if not self.force_regenerate:
if self._ca_key_path is None:
ca_key_path = self._ca_key_path
assert os.path.exists(ca_key_path)
path = self.path
req_path = path("wpt.req")
cert_path = path("%s.pem" % host)
key_path = path("%s.key" % host)
self.logger.info("Generating new host cert")
with self._config_openssl(hosts) as openssl:
"-newkey", "rsa:2048",
"-keyout", key_path,
"-in", ca_key_path,
"-out", req_path)
"-in", req_path,
"-passin", "pass:%s" % self.password,
"-subj", make_subject(host),
"-out", cert_path)
return key_path, cert_path
@ -0,0 +1,26 @@
class PregeneratedSSLEnvironment(object):
"""SSL environment to use with existing key/certificate files
e.g. when running on a server with a public domain name
ssl_enabled = True
def __init__(self, logger, host_key_path, host_cert_path,
self._ca_cert_path = ca_cert_path
self._host_key_path = host_key_path
self._host_cert_path = host_cert_path
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
def host_cert_path(self, hosts):
"""Return the key and certificate paths for the host"""
return self._host_key_path, self._host_cert_path
def ca_cert_path(self, hosts):
"""Return the certificate path of the CA that signed the
host certificates, or None if that isn't known"""
return self._ca_cert_path
@ -0,0 +1,204 @@
import base64
import json
import os
import six
import threading
import uuid
from multiprocessing.managers import AcquirerProxy, BaseManager, DictProxy
from six import text_type, binary_type
from .utils import isomorphic_encode
class ServerDictManager(BaseManager):
shared_data = {}
def _get_shared():
return ServerDictManager.shared_data
ServerDictManager.register('Lock', threading.Lock, AcquirerProxy)
class ClientDictManager(BaseManager):
class StashServer(object):
def __init__(self, address=None, authkey=None, mp_context=None):
self.address = address
self.authkey = authkey
self.manager = None
self.mp_context = mp_context
def __enter__(self):
self.manager, self.address, self.authkey = start_server(self.address,
store_env_config(self.address, self.authkey)
def __exit__(self, *args, **kwargs):
if self.manager is not None:
def load_env_config():
address, authkey = json.loads(os.environ["WPT_STASH_CONFIG"])
if isinstance(address, list):
address = tuple(address)
address = str(address)
authkey = base64.b64decode(authkey)
return address, authkey
def store_env_config(address, authkey):
authkey = base64.b64encode(authkey)
os.environ["WPT_STASH_CONFIG"] = json.dumps((address, authkey.decode("ascii")))
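The round trip performed by `store_env_config` and `load_env_config` can be illustrated without touching `os.environ`. This is a sketch assuming Python 3; `pack_config` and `unpack_config` are hypothetical names, and the `str()` conversion of the address from the original is omitted.

```python
import base64
import json


def pack_config(address, authkey):
    """Serialise (address, authkey) to a JSON string; authkey is bytes."""
    # Raw bytes are not valid JSON, so the key travels as base64 text.
    return json.dumps((address, base64.b64encode(authkey).decode("ascii")))


def unpack_config(text):
    """Inverse of pack_config."""
    address, authkey = json.loads(text)
    if isinstance(address, list):
        # JSON has no tuple type, so a (host, port) pair comes back as a list.
        address = tuple(address)
    return address, base64.b64decode(authkey)
```

Both a `(host, port)` tuple and a plain string address (for example a socket path) survive the round trip.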
def start_server(address=None, authkey=None, mp_context=None):
if isinstance(authkey, text_type):
authkey = authkey.encode("ascii")
kwargs = {}
if six.PY3 and mp_context is not None:
kwargs["ctx"] = mp_context
manager = ServerDictManager(address, authkey, **kwargs)
address = manager._address
if isinstance(address, bytes):
address = address.decode("ascii")
return (manager, address, manager._authkey)
class LockWrapper(object):
def __init__(self, lock):
self.lock = lock
def acquire(self):
def release(self):
def __enter__(self):
def __exit__(self, *args, **kwargs):
#TODO: Consider expiring values after some fixed time for long-running
class Stash(object):
"""Key-value store for persisting data across HTTP/S and WS/S requests.
This data store is specifically designed for persisting data across server
requests. The synchronization is achieved by using the BaseManager from
the multiprocessing module so different processes can access the same data.
Stash can be used interchangeably between HTTP, HTTPS, WS and WSS servers.
A thing to note about WS/S servers is that they require additional steps in
the handlers for accessing the same underlying shared data in the Stash.
This can usually be achieved by using load_env_config(). When using Stash
interchangeably between HTTP/S and WS/S requests, the path part of the key
should be explicitly specified if accessing the same key/value subset.
The store has several unusual properties. Keys are of the form (path,
uuid), where path is, by default, the path in the HTTP request and
uuid is a unique id. In addition, the store is write-once, read-once,
i.e. the value associated with a particular key cannot be changed once
written and the read operation (called "take") is destructive. Taken together,
these properties make it difficult for data to accidentally leak
between different resources or different requests for the same
_proxy = None
lock = None
_initializing = threading.Lock()
def __init__(self, default_path, address=None, authkey=None):
self.default_path = default_path
self._get_proxy(address, authkey)
self.data = Stash._proxy
def _get_proxy(self, address=None, authkey=None):
if address is None and authkey is None:
Stash._proxy = {}
Stash.lock = threading.Lock()
# Initializing the proxy involves connecting to the remote process and
# retrieving two proxied objects. This process is not inherently
# atomic, so a lock must be used to make it so. Atomicity ensures that
# only one thread attempts to initialize the connection and that any
# threads running in parallel correctly wait for initialization to be
# fully complete.
with Stash._initializing:
if Stash.lock:
manager = ClientDictManager(address, authkey)
Stash._proxy = manager.get_dict()
Stash.lock = LockWrapper(manager.Lock())
def _wrap_key(self, key, path):
if path is None:
path = self.default_path
# This key format is required to support using the path, since the data
# passed into the stash can be a DictProxy, which wouldn't detect
# changes when writing to a subdict.
if isinstance(key, binary_type):
# UUIDs are within the ASCII charset.
key = key.decode('ascii')
return (isomorphic_encode(path), uuid.UUID(key).bytes)
def put(self, key, value, path=None):
"""Place a value in the shared stash.
:param key: A UUID to use as the data's key.
:param value: The data to store. This can be any python object.
:param path: The path that has access to read the data (by default
the current request path)"""
if value is None:
raise ValueError("SharedStash value may not be set to None")
internal_key = self._wrap_key(key, path)
if internal_key in self.data:
raise StashError("Tried to overwrite existing shared stash value "
"for key %s (old value was %s, new value is %s)" %
(internal_key, self.data[internal_key], value))
self.data[internal_key] = value
def take(self, key, path=None):
"""Remove a value from the shared stash and return it.
:param key: A UUID to use as the data's key.
:param path: The path that has access to read the data (by default
the current request path)"""
internal_key = self._wrap_key(key, path)
value = self.data.get(internal_key, None)
if value is not None:
except KeyError:
# Silently continue when pop error occurs.
return value
class StashError(Exception):
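The write-once, read-once behaviour described in the `Stash` docstring can be sketched with a plain dictionary. This is a toy model only: `OnceStore` is a hypothetical name, it does no locking and no cross-process sharing, and it raises `KeyError` where the real class raises `StashError`.

```python
import uuid


class OnceStore(object):
    """Toy write-once, read-once store keyed on (path, uuid)."""

    def __init__(self, default_path):
        self.default_path = default_path
        self.data = {}

    def _wrap_key(self, key, path):
        if path is None:
            path = self.default_path
        # Parsing the UUID normalises different spellings of the same id.
        return (path, uuid.UUID(key).bytes)

    def put(self, key, value, path=None):
        if value is None:
            raise ValueError("value may not be None")
        internal_key = self._wrap_key(key, path)
        if internal_key in self.data:
            raise KeyError("key already written")
        self.data[internal_key] = value

    def take(self, key, path=None):
        # pop() makes the read destructive: a second take returns None.
        return self.data.pop(self._wrap_key(key, path), None)
```

Because the path is part of the key, a value stored under one path is invisible to a `take` issued for another path, even with the same UUID.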
@ -0,0 +1,165 @@
import socket
import sys
from six import binary_type, text_type
def isomorphic_decode(s):
"""Decodes a binary string into a text string using iso-8859-1.
Returns `unicode` in Python 2 and `str` in Python 3. The function is a
no-op if the argument already has a text type. iso-8859-1 is chosen because
it is an 8-bit encoding whose code points range from 0x0 to 0xFF and the
values are the same as the binary representations, so any binary string can
be decoded into and encoded from iso-8859-1 without any errors or data
loss. Python 3 also uses iso-8859-1 (or latin-1) extensively in http:
if isinstance(s, text_type):
return s
if isinstance(s, binary_type):
return s.decode("iso-8859-1")
raise TypeError("Unexpected value (expecting string-like): %r" % s)
def isomorphic_encode(s):
"""Encodes a text-type string into binary data using iso-8859-1.
Returns `str` in Python 2 and `bytes` in Python 3. The function is a no-op
if the argument already has a binary type. This is the counterpart of
if isinstance(s, binary_type):
return s
if isinstance(s, text_type):
return s.encode("iso-8859-1")
raise TypeError("Unexpected value (expecting string-like): %r" % s)
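The lossless round trip that motivates iso-8859-1 in the two docstrings above can be checked directly. The helper names `bytes_to_text` and `text_to_bytes` are hypothetical and skip the type checks of the real functions.

```python
def bytes_to_text(data):
    """Map each byte 0x00-0xFF to the code point with the same value."""
    return data.decode("iso-8859-1")


def text_to_bytes(text):
    """Inverse of bytes_to_text; fails for code points above 0xFF."""
    return text.encode("iso-8859-1")


# Every possible byte value survives a decode/encode round trip.
every_byte = bytes(bytearray(range(256)))
round_tripped = text_to_bytes(bytes_to_text(every_byte))
```

The reverse direction is not total: text containing a code point above 0xFF cannot be encoded and raises `UnicodeEncodeError`.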
def invert_dict(dict):
rv = {}
for key, values in dict.items():
for value in values:
if value in rv:
raise ValueError
rv[value] = key
return rv
class HTTPException(Exception):
def __init__(self, code, message=""):
self.code = code
self.message = message
def _open_socket(host, port):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if port != 0:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
return sock
def is_bad_port(port):
Bad port as per https://fetch.spec.whatwg.org/#port-blocking
return port in [
1, # tcpmux
7, # echo
9, # discard
11, # systat
13, # daytime
15, # netstat
17, # qotd
19, # chargen
20, # ftp-data
21, # ftp
22, # ssh
23, # telnet
25, # smtp
37, # time
42, # name
43, # nicname
53, # domain
77, # priv-rjs
79, # finger
87, # ttylink
95, # supdup
101, # hostname
102, # iso-tsap
103, # gppitnp
104, # acr-nema
109, # pop2
110, # pop3
111, # sunrpc
113, # auth
115, # sftp
117, # uucp-path
119, # nntp
123, # ntp
135, # loc-srv / epmap
139, # netbios
143, # imap2
179, # bgp
389, # ldap
427, # afp (alternate)
465, # smtp (alternate)
512, # print / exec
513, # login
514, # shell
515, # printer
526, # tempo
530, # courier
531, # chat
532, # netnews
540, # uucp
548, # afp
554, # rtsp
556, # remotefs
563, # nntp+ssl
587, # smtp (outgoing)
601, # syslog-conn
636, # ldap+ssl
993, # imap+ssl
995, # pop3+ssl
1720, # h323hostcall
1723, # pptp
2049, # nfs
3659, # apple-sasl
4045, # lockd
5060, # sip
5061, # sips
6000, # x11
6665, # irc (alternate)
6666, # irc (alternate)
6667, # irc (default)
6668, # irc (alternate)
6669, # irc (alternate)
6697, # irc+tls
def get_port(host=''):
host = host or ''
port = 0
while True:
free_socket = _open_socket(host, 0)
port = free_socket.getsockname()[1]
if not is_bad_port(port):
return port
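The technique used by `get_port` is to bind to port 0, let the operating system choose a free port, and retry if the result is on the blocked list. A minimal sketch follows; `pick_free_port` and its `is_blocked` parameter are hypothetical, and the default host is an assumption.

```python
import socket


def pick_free_port(host="127.0.0.1", is_blocked=lambda port: False):
    """Ask the OS for an unused TCP port, skipping ports that are blocked."""
    while True:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Port 0 tells the OS to pick any free port.
            sock.bind((host, 0))
            port = sock.getsockname()[1]
        finally:
            sock.close()
        if not is_blocked(port):
            return port
```

The socket is closed before returning, so another process could claim the port before the caller binds it; the approach is best-effort.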
def http2_compatible():
# Currently, the HTTP/2.0 server only works with Python 2.7.10+ or 3.6+ and OpenSSL 1.0.2+
import ssl
ssl_v = ssl.OPENSSL_VERSION_INFO
py_v = sys.version_info
return (((py_v[0] == 2 and py_v[1] == 7 and py_v[2] >= 10) or (py_v[0] == 3 and py_v[1] >= 6)) and
(ssl_v[0] == 1 and (ssl_v[1] == 1 or (ssl_v[1] == 0 and ssl_v[2] >= 2))))
@ -0,0 +1,33 @@
#!/usr/bin/env python
import argparse
import os
import server
def abs_path(path):
return os.path.abspath(path)
def parse_args():
parser = argparse.ArgumentParser(description="HTTP server designed for extreme flexibility "
"required in testing situations.")
parser.add_argument("document_root", action="store", type=abs_path,
help="Root directory to serve files from")
parser.add_argument("--port", "-p", dest="port", action="store",
type=int, default=8000,
help="Port number to run server on")
parser.add_argument("--host", "-H", dest="host", action="store",
type=str, default="",
help="Host to run server on")
return parser.parse_args()
def main():
args = parse_args()
httpd = server.WebTestHttpd(host=args.host, port=args.port,
use_ssl=False, certificate=None,
if __name__ == "__main__":
@ -0,0 +1,247 @@
"""This file provides the opening handshake processor for the Bootstrapping
WebSockets with HTTP/2 protocol (RFC 8441).
from __future__ import absolute_import
from mod_pywebsocket import common
from mod_pywebsocket.stream import Stream
from mod_pywebsocket.stream import StreamOptions
from mod_pywebsocket import util
from six.moves import map
from six.moves import range
# TODO: We are using "private" methods of pywebsocket. We might want to
# refactor pywebsocket to expose those methods publicly. Also, _get_origin,
# _check_version, _set_protocol, _parse_extensions and a large part of do_handshake
# are identical to the hybi handshake. We need some refactoring to remove that
# code duplication.
from mod_pywebsocket.extensions import get_extension_processor
from mod_pywebsocket.handshake._base import get_mandatory_header
from mod_pywebsocket.handshake._base import HandshakeException
from mod_pywebsocket.handshake._base import parse_token_list
from mod_pywebsocket.handshake._base import validate_mandatory_header
from mod_pywebsocket.handshake._base import validate_subprotocol
from mod_pywebsocket.handshake._base import VersionException
# Defining aliases for values used frequently.
def check_connect_method(request):
if request.method != u'CONNECT':
raise HandshakeException('Method is not CONNECT: %r' % request.method)
class WsH2Handshaker(object):
def __init__(self, request, dispatcher):
"""Opening handshake processor for the WebSocket protocol (RFC 6455).
:param request: mod_python request.
:param dispatcher: Dispatcher (dispatch.Dispatcher).
WsH2Handshaker will add attributes such as ws_resource during handshake.
self._logger = util.get_class_logger(self)
self._request = request
self._dispatcher = dispatcher
def _get_extension_processors_requested(self):
processors = []
if self._request.ws_requested_extensions is not None:
for extension_request in self._request.ws_requested_extensions:
processor = get_extension_processor(extension_request)
# Unknown extension requests are just ignored.
if processor is not None:
return processors
def do_handshake(self):
self._request.ws_close_code = None
self._request.ws_close_reason = None
# Parsing.
validate_mandatory_header(self._request, ':protocol', 'websocket')
self._request.ws_resource = self._request.uri
get_mandatory_header(self._request, 'authority')
self._request.ws_version = self._check_version()
# Setup extension processors.
self._request.ws_extension_processors = self._get_extension_processors_requested()
# List of extra headers. The extra handshake handler may add header
# data as name/value pairs to this list and pywebsocket appends
# them to the WebSocket handshake.
self._request.extra_headers = []
# Extra handshake handler may modify/remove processors.
processors = [
for processor in self._request.ws_extension_processors
if processor is not None
# Ask each processor if there are extensions on the request which
# cannot co-exist. When a processor decides that other processors cannot
# co-exist with it, the processor marks them (or itself) as
# "inactive". The first extension processor has the right to
# make the final call.
for processor in reversed(processors):
if processor.is_active():
processors = [
processor for processor in processors if processor.is_active()
accepted_extensions = []
stream_options = StreamOptions()
for index, processor in enumerate(processors):
if not processor.is_active():
extension_response = processor.get_extension_response()
if extension_response is None:
# Rejected.
# Inactivate all of the following compression extensions.
for j in range(index + 1, len(processors)):
if len(accepted_extensions) > 0:
self._request.ws_extensions = accepted_extensions
'Extensions accepted: %r',
self._request.ws_extensions = None
self._request.ws_stream = self._create_stream(stream_options)
if self._request.ws_requested_protocols is not None:
if self._request.ws_protocol is None:
raise HandshakeException(
'do_extra_handshake must choose one subprotocol from '
'ws_requested_protocols and set it to ws_protocol')
self._logger.debug('Subprotocol accepted: %r',
if self._request.ws_protocol is not None:
raise HandshakeException(
'ws_protocol must be None when the client didn\'t '
'request any subprotocol')
except HandshakeException as e:
if not e.status:
# Fallback to 400 bad request by default.
e.status = common.HTTP_STATUS_BAD_REQUEST
raise e
def _get_origin(self):
origin = self._request.headers_in.get('origin')
if origin is None:
self._logger.debug('Client request does not have origin header')
self._request.ws_origin = origin
def _check_version(self):
sec_websocket_version_header = 'sec-websocket-version'
version = get_mandatory_header(self._request, sec_websocket_version_header)
if version.find(',') >= 0:
raise HandshakeException(
'Multiple versions (%r) are not allowed for header %s' %
(version, sec_websocket_version_header),
raise VersionException('Unsupported version %r for header %s' %
(version, sec_websocket_version_header),
supported_versions=', '.join(
def _set_protocol(self):
self._request.ws_protocol = None
protocol_header = self._request.headers_in.get('sec-websocket-protocol')
if protocol_header is None:
self._request.ws_requested_protocols = None
self._request.ws_requested_protocols = parse_token_list(
self._logger.debug('Subprotocols requested: %r',
def _parse_extensions(self):
extensions_header = self._request.headers_in.get('sec-websocket-extensions')
if not extensions_header:
self._request.ws_requested_extensions = None
self._request.ws_requested_extensions = common.parse_extensions(
except common.ExtensionParsingException as e:
raise HandshakeException(
'Failed to parse sec-websocket-extensions header: %r' % e)
'Extensions requested: %r',
def _create_stream(self, stream_options):
return Stream(self._request, stream_options)
def _prepare_handshake_response(self):
self._request.status = 200
self._request.headers_out['upgrade'] = common.WEBSOCKET_UPGRADE_TYPE
self._request.headers_out['connection'] = common.UPGRADE_CONNECTION_TYPE
if self._request.ws_protocol is not None:
self._request.headers_out['sec-websocket-protocol'] = self._request.ws_protocol
if (self._request.ws_extensions is not None and
len(self._request.ws_extensions) != 0):
self._request.headers_out['sec-websocket-extensions'] = common.format_extensions(self._request.ws_extensions)
# Headers not specific for WebSocket
for name, value in self._request.extra_headers:
self._request.headers_out[name] = value
@ -7,6 +7,7 @@ black:
- python/mozbuild/mozbuild/fork_interpose.py
- python/mozbuild/mozbuild/test/frontend/data/reader-error-syntax/moz.build
- testing/mozharness/configs/test/test_malformed.py
- testing/web-platform/mozilla/tests/tools/wptserve_py2
- testing/web-platform/tests
- build
@ -166,6 +166,7 @@ file-whitespace:
- testing/web-platform/tests/tools/wptrunner/wptrunner/tests/test_update.py
- testing/web-platform/tests/tools/lint/tests/dummy/broken.html
- testing/web-platform/tests/tools/lint/tests/dummy/broken_ignored.html
- testing/web-platform/mozilla/tests/tools/wptserve_py2
- toolkit/components/telemetry/build_scripts/setup.py
- toolkit/components/telemetry/tests/marionette/mach_commands.py
- toolkit/content/tests/chrome
@ -17,6 +17,7 @@ py3:
- testing/mozharness
- testing/tps
- testing/web-platform/tests
- testing/web-platform/mozilla/tests/tools/wptserve_py2/
- toolkit
- xpcom/idl-parser
extensions: ['py']