Bug 1695263 - Vendor in a copy of wptserve that's still Python 2 compatible, r=marionette-reviewers,whimboo

Upstream wptserve just switched to Python 3 only. That's fine for
web-platform-tests, but it turns out that some marionette harness
tests are also using wptserve and are still on Python 2.

Since fixing marionette harness turns out to be non-trivial and this
blocks other wpt work, this patch does the following:

* Temporarily vendors the last wptserve revision that works with
  Python 2 in to testing/web-platform/mozilla/tests/tools/wptserve_py2

* Configures the mach virtualenv to use that copy for Python 2 modules
  only.

* Configures the test packaging system to also put that copy in the
  common tests zip. Requirements files are updated to use either the
  Python 2 version or the Python 3 version as required.

Differential Revision: https://phabricator.services.mozilla.com/D106764
James Graham 2021-03-03 10:03:05 +00:00
Parent 1a97ef9671
Commit 7814585840
98 changed files with 9109 additions and 4 deletions

View file

@@ -73,6 +73,7 @@ exclude =
testing/marionette/harness/marionette_harness/tests,
testing/mochitest/pywebsocket3,
testing/mozharness/configs/test/test_malformed.py,
+testing/web-platform/mozilla/tests/tools/wptserve_py2,
testing/web-platform/tests,
tools/lint/test/files,
tools/infer/test/*.configure,

View file

@@ -34,7 +34,8 @@ mozilla.pth:testing/web-platform/tests/tools/third_party/html5lib
mozilla.pth:testing/web-platform/tests/tools/third_party/hyperframe
mozilla.pth:testing/web-platform/tests/tools/third_party/pywebsocket3
mozilla.pth:testing/web-platform/tests/tools/third_party/webencodings
-mozilla.pth:testing/web-platform/tests/tools/wptserve
+python3:mozilla.pth:testing/web-platform/tests/tools/wptserve
+python2:mozilla.pth:testing/web-platform/mozilla/tests/tools/wptserve_py2
mozilla.pth:testing/web-platform/tests/tools/wptrunner
mozilla.pth:testing/xpcshell
mozilla.pth:third_party/python/appdirs

View file

@@ -170,6 +170,12 @@ ARCHIVE_FILES = {
"pattern": "**",
"dest": "tps/tests",
},
+{
+"source": buildconfig.topsrcdir,
+"base": "testing/web-platform/mozilla/tests/tools/wptserve_py2",
+"pattern": "**",
+"dest": "tools/wptserve_py2",
+},
{
"source": buildconfig.topsrcdir,
"base": "testing/web-platform/tests/tools/wptserve",
@@ -577,7 +583,6 @@ ARCHIVE_FILES = {
"base": "testing",
"pattern": "web-platform/tests/**",
"ignore": [
-"web-platform/tests/tools/wptserve",
"web-platform/tests/tools/wpt_third_party",
],
},

View file

@@ -1,6 +1,7 @@
-r mozbase_requirements.txt
-../tools/wptserve
+../tools/wptserve_py2 ; python_version < '3'
+../tools/wptserve ; python_version >= '3'
../tools/wpt_third_party/certifi
../tools/wpt_third_party/enum ; python_version < '3'
../tools/wpt_third_party/h2

View file

@@ -2,7 +2,8 @@
-r mozbase_source_requirements.txt
-../web-platform/tests/tools/wptserve
+../web-platform/tests/tools/wptserve ; python_version >= '3'
+../web-platform/mozilla/tests/tools/wptserve_py2 ; python_version < '3'
../web-platform/tests/tools/third_party/certifi
../web-platform/tests/tools/third_party/enum ; python_version < '3'
../web-platform/tests/tools/third_party/h2

View file

@@ -0,0 +1,11 @@
[run]
branch = True
parallel = True
omit =
*/site-packages/*
*/lib_pypy/*
[paths]
wptserve =
wptserve
.tox/**/site-packages/wptserve

View file

@@ -0,0 +1,40 @@
*.py[cod]
*~
\#*
docs/_build/
# C extensions
*.so
# Packages
*.egg
*.egg-info
dist
build
eggs
parts
bin
var
sdist
develop-eggs
.installed.cfg
lib
lib64
# Installer logs
pip-log.txt
# Unit test / coverage reports
.coverage
.tox
nosetests.xml
tests/functional/html/*
# Translations
*.mo
# Mr Developer
.mr.developer.cfg
.project
.pydevproject

View file

@@ -0,0 +1,11 @@
# The 3-Clause BSD License
Copyright 2019 web-platform-tests contributors
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@@ -0,0 +1 @@
include README.md

View file

@@ -0,0 +1,8 @@
Python 2 compatible version of wptserve.
This should be removed as soon as Python 2 is no longer required by
any code that uses wptserve (notably marionette-harness based tests).
When removing it, reverting the commit that added it should restore
the previous configuration, in which wptserve is used from the wpt
import.

View file

@@ -0,0 +1,153 @@
# Makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
PAPER =
BUILDDIR = _build
# Internal variables.
PAPEROPT_a4 = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
# the i18n builder cannot share the environment and doctrees with the others
I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
help:
@echo "Please use \`make <target>' where <target> is one of"
@echo " html to make standalone HTML files"
@echo " dirhtml to make HTML files named index.html in directories"
@echo " singlehtml to make a single large HTML file"
@echo " pickle to make pickle files"
@echo " json to make JSON files"
@echo " htmlhelp to make HTML files and a HTML help project"
@echo " qthelp to make HTML files and a qthelp project"
@echo " devhelp to make HTML files and a Devhelp project"
@echo " epub to make an epub"
@echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
@echo " latexpdf to make LaTeX files and run them through pdflatex"
@echo " text to make text files"
@echo " man to make manual pages"
@echo " texinfo to make Texinfo files"
@echo " info to make Texinfo files and run them through makeinfo"
@echo " gettext to make PO message catalogs"
@echo " changes to make an overview of all changed/added/deprecated items"
@echo " linkcheck to check all external links for integrity"
@echo " doctest to run all doctests embedded in the documentation (if enabled)"
clean:
-rm -rf $(BUILDDIR)/*
html:
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
singlehtml:
$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
@echo
@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
pickle:
$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
@echo
@echo "Build finished; now you can process the pickle files."
json:
$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
@echo
@echo "Build finished; now you can process the JSON files."
htmlhelp:
$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
@echo
@echo "Build finished; now you can run HTML Help Workshop with the" \
".hhp project file in $(BUILDDIR)/htmlhelp."
qthelp:
$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
@echo
@echo "Build finished; now you can run "qcollectiongenerator" with the" \
".qhcp project file in $(BUILDDIR)/qthelp, like this:"
@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/wptserve.qhcp"
@echo "To view the help file:"
@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/wptserve.qhc"
devhelp:
$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
@echo
@echo "Build finished."
@echo "To view the help file:"
@echo "# mkdir -p $$HOME/.local/share/devhelp/wptserve"
@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/wptserve"
@echo "# devhelp"
epub:
$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
@echo
@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
latex:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo
@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
@echo "Run \`make' in that directory to run these through (pdf)latex" \
"(use \`make latexpdf' here to do that automatically)."
latexpdf:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo "Running LaTeX files through pdflatex..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
text:
$(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
@echo
@echo "Build finished. The text files are in $(BUILDDIR)/text."
man:
$(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
@echo
@echo "Build finished. The manual pages are in $(BUILDDIR)/man."
texinfo:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo
@echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
@echo "Run \`make' in that directory to run these through makeinfo" \
"(use \`make info' here to do that automatically)."
info:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo "Running Texinfo files through makeinfo..."
make -C $(BUILDDIR)/texinfo info
@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
gettext:
$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
@echo
@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
changes:
$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
@echo
@echo "The overview file is in $(BUILDDIR)/changes."
linkcheck:
$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
@echo
@echo "Link check complete; look for any errors in the above output " \
"or in $(BUILDDIR)/linkcheck/output.txt."
doctest:
$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in $(BUILDDIR)/doctest/output.txt."

View file

@@ -0,0 +1,243 @@
# -*- coding: utf-8 -*-
#
# wptserve documentation build configuration file, created by
# sphinx-quickstart on Wed Aug 14 17:23:24 2013.
#
# This file is execfile()d with the current directory set to its containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
import sys, os
sys.path.insert(0, os.path.abspath(".."))
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#sys.path.insert(0, os.path.abspath('.'))
# -- General configuration -----------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix of source filenames.
source_suffix = '.rst'
# The encoding of source files.
#source_encoding = 'utf-8-sig'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = u'wptserve'
copyright = u'2013, Mozilla Foundation and other wptserve contributors'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.1'
# The full version, including alpha/beta/rc tags.
release = '0.1'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#language = None
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
#today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build']
# The reST default role (used for this markup: `text`) to use for all documents.
#default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []
# -- Options for HTML output ---------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = 'default'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#html_theme_options = {}
# Add any paths that contain custom themes here, relative to this directory.
#html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
#html_title = None
# A shorter title for the navigation bar. Default is the same as html_title.
#html_short_title = None
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
#html_logo = None
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
#html_favicon = None
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
#html_use_smartypants = True
# Custom sidebar templates, maps document names to template names.
#html_sidebars = {}
# Additional templates that should be rendered to pages, maps page names to
# template names.
#html_additional_pages = {}
# If false, no module index is generated.
#html_domain_indices = True
# If false, no index is generated.
#html_use_index = True
# If true, the index is split into individual pages for each letter.
#html_split_index = False
# If true, links to the reST sources are added to the pages.
#html_show_sourcelink = True
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
#html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
#html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
#html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml").
#html_file_suffix = None
# Output file base name for HTML help builder.
htmlhelp_basename = 'wptservedoc'
# -- Options for LaTeX output --------------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [
('index', 'wptserve.tex', u'wptserve Documentation',
u'James Graham', 'manual'),
]
# The name of an image file (relative to this directory) to place at the top of
# the title page.
#latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
#latex_use_parts = False
# If true, show page references after internal links.
#latex_show_pagerefs = False
# If true, show URL addresses after external links.
#latex_show_urls = False
# Documents to append as an appendix to all manuals.
#latex_appendices = []
# If false, no module index is generated.
#latex_domain_indices = True
# -- Options for manual page output --------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
('index', 'wptserve', u'wptserve Documentation',
[u'James Graham'], 1)
]
# If true, show URL addresses after external links.
#man_show_urls = False
# -- Options for Texinfo output ------------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
('index', 'wptserve', u'wptserve Documentation',
u'James Graham', 'wptserve', 'One line description of project.',
'Miscellaneous'),
]
# Documents to append as an appendix to all manuals.
#texinfo_appendices = []
# If false, no module index is generated.
#texinfo_domain_indices = True
# How to display URL addresses: 'footnote', 'no', or 'inline'.
#texinfo_show_urls = 'footnote'

View file

@@ -0,0 +1,108 @@
Handlers
========
Handlers are functions that have the general signature::
handler(request, response)
It is expected that the handler will use information from
the request (e.g. the path) either to populate the response
object with the data to send, or to directly write to the
output stream via the ResponseWriter instance associated with
the request. If a handler writes to the output stream then the
server will not attempt additional writes, i.e. the choice to write
directly in the handler or not is all-or-nothing.
A number of general-purpose handler functions are provided by default:
.. _handlers.Python:
Python Handlers
---------------
Python handlers are functions which provide a higher-level API over
manually updating the response object, by causing the return value of
the function to provide (part of) the response. There are four
possible sets of values that may be returned::
((status_code, reason), headers, content)
(status_code, headers, content)
(headers, content)
content
Here `status_code` is an integer status code, `headers` is a list of (field
name, value) pairs, and `content` is a string or an iterable returning strings.
Such a function may also update the response manually. For example one may use
`response.headers.set` to set a response header, and only return the content.
One may even use this kind of handler, but manipulate the output socket
directly, in which case the return value of the function, and the properties of
the response object, will be ignored.
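As a sketch of how those four shapes relate, the following is illustrative normalization logic only (not wptserve's actual implementation), unpacking each form into a uniform triple:

```python
def normalize_return(rv):
    # Illustrative only: unpack the four documented handler return shapes
    # into a uniform ((status_code, reason), headers, content) triple.
    if isinstance(rv, tuple) and len(rv) == 3:
        status, headers, content = rv
        if not isinstance(status, tuple):
            status = (status, None)  # bare status code, no reason phrase
        return status, headers, content
    if isinstance(rv, tuple) and len(rv) == 2:
        headers, content = rv
        return (200, None), headers, content
    return (200, None), [], rv  # bare content, default status and headers
```

So `content` alone implies a 200 response with no extra headers, while the longer tuples progressively override the status line and headers.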
The most common way to make a user function into a python handler is
to use the provided `wptserve.handlers.handler` decorator::
from wptserve.handlers import handler
@handler
def test(request, response):
return [("X-Test", "PASS"), ("Content-Type", "text/plain")], "test"
#Later, assuming we have a Router object called 'router'
router.register("GET", "/test", test)
JSON Handlers
-------------
This is a specialisation of the python handler type specifically
designed to facilitate providing JSON responses. The API is largely
the same as for a normal python handler, but the `content` part of the
return value is JSON encoded, and a default Content-Type header of
`application/json` is added. Again this handler is usually used as a
decorator::
from wptserve.handlers import json_handler
@json_handler
def test(request, response):
return {"test": "PASS"}
Python File Handlers
--------------------
Python file handlers are Python files which the server executes in response to
requests made to the corresponding URL. This is hooked up to a route like
``("*", "*.py", python_file_handler)``, meaning that any .py file will be
treated as a handler file (note that this makes it easy to write unsafe
handlers, particularly when running the server in a web-exposed setting).
The Python files must define a single function `main` with the signature::
main(request, response)
This function then behaves just like those described in
:ref:`handlers.Python` above.
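A minimal handler file might look like this (the filename and returned values here are hypothetical; the file needs no imports of its own):

```python
# example.py -- executed for requests to /example.py (hypothetical path)
def main(request, response):
    # Behaves like a Python handler: return (headers, content).
    return [("Content-Type", "text/plain")], "handled by main()"
```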
asis Handlers
-------------
These are used to serve files as literal byte streams including the
HTTP status line, headers and body. In the default configuration this
handler is invoked for all files with a .asis extension.
File Handlers
-------------
File handlers are used to serve static files. By default the content
type of these files is set by examining the file extension. However
this can be overridden, or additional headers supplied, by providing a
file with the same name as the file being served but an additional
.headers suffix, i.e. test.html has its headers set from
test.html.headers. The format of the .headers file is plaintext, with
each line containing::
Header-Name: header_value
In addition headers can be set for a whole directory of files (but not
subdirectories), using a file called `__dir__.headers`.
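For example, a hypothetical `test.html.headers` file overriding the
content type and adding a caching header might contain::

    Content-Type: text/html; charset=utf-8
    Cache-Control: no-store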

View file

@@ -0,0 +1,27 @@
.. wptserve documentation master file, created by
sphinx-quickstart on Wed Aug 14 17:23:24 2013.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
wptserve: Web Platform Test Server
==================================
A python-based HTTP server specifically targeted at being used for
testing the web platform. This means that extreme flexibility —
including the possibility of HTTP non-conformance — in the response is
supported.
Contents:
.. toctree::
:maxdepth: 2
introduction
server
router
request
response
stash
handlers
pipes

View file

@@ -0,0 +1,51 @@
Introduction
============
wptserve has been designed with the specific goal of making a server
that is suitable for writing tests for the web platform. This means
that it cannot use common abstractions over HTTP such as WSGI, since
these assume that the goal is to generate a well-formed HTTP
response. Testcases, however, often require precise control of the
exact bytes sent over the wire and their timing. The full list of
design goals for the server are:
* Suitable to run on individual test machines and over the public internet.
* Support plain TCP and SSL servers.
* Serve static files with the minimum of configuration.
* Allow headers to be overwritten on a per-file and per-directory
basis.
* Full customisation of headers sent (e.g. altering or omitting
"mandatory" headers).
* Simple per-client state.
* Complex logic in tests, up to precise control over the individual
bytes sent and the timing of sending them.
Request Handling
----------------
At the high level, the design of the server is based around similar
concepts to those found in common web frameworks like Django, Pyramid
or Flask. In particular the lifecycle of a typical request will be
familiar to users of these systems. Incoming requests are parsed and a
:doc:`Request <request>` object is constructed. This object is passed
to a :ref:`Router <router.Interface>` instance, which is
responsible for mapping the request method and path to a handler
function. This handler is passed two arguments; the request object and
a :doc:`Response <response>` object. In cases where only simple
responses are required, the handler function may fill in the
properties of the response object and the server will take care of
constructing the response. However each Response also contains a
:ref:`ResponseWriter <response.Interface>` which can be
used to directly control the TCP socket.
By default there are several built-in handler functions that provide a
higher level API than direct manipulation of the Response
object. These are documented in :doc:`handlers`.
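The lifecycle above can be sketched with plain stub classes; these stand in for wptserve's real Request, Response and Router types and are illustrative only:

```python
class Request:
    """Stub for the parsed incoming request."""
    def __init__(self, method, path):
        self.method = method
        self.path = path

class Response:
    """Stub holding the properties a handler fills in."""
    def __init__(self):
        self.status = 200
        self.headers = []
        self.content = b""

def dispatch(routes, request):
    # Minimal router stand-in: exact match on method and path,
    # then hand the request and a fresh response to the handler.
    for method, path, handler in routes:
        if method == request.method and path == request.path:
            response = Response()
            handler(request, response)
            return response
    raise LookupError("no route for %s %s" % (request.method, request.path))

def hello(request, response):
    response.headers.append(("Content-Type", "text/plain"))
    response.content = b"hello"

resp = dispatch([("GET", "/hello", hello)], Request("GET", "/hello"))
```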

View file

@@ -0,0 +1,190 @@
@ECHO OFF
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set BUILDDIR=_build
set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
set I18NSPHINXOPTS=%SPHINXOPTS% .
if NOT "%PAPER%" == "" (
set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
)
if "%1" == "" goto help
if "%1" == "help" (
:help
echo.Please use `make ^<target^>` where ^<target^> is one of
echo. html to make standalone HTML files
echo. dirhtml to make HTML files named index.html in directories
echo. singlehtml to make a single large HTML file
echo. pickle to make pickle files
echo. json to make JSON files
echo. htmlhelp to make HTML files and a HTML help project
echo. qthelp to make HTML files and a qthelp project
echo. devhelp to make HTML files and a Devhelp project
echo. epub to make an epub
echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
echo. text to make text files
echo. man to make manual pages
echo. texinfo to make Texinfo files
echo. gettext to make PO message catalogs
echo. changes to make an overview over all changed/added/deprecated items
echo. linkcheck to check all external links for integrity
echo. doctest to run all doctests embedded in the documentation if enabled
goto end
)
if "%1" == "clean" (
for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
del /q /s %BUILDDIR%\*
goto end
)
if "%1" == "html" (
%SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%/html.
goto end
)
if "%1" == "dirhtml" (
%SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
goto end
)
if "%1" == "singlehtml" (
%SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
goto end
)
if "%1" == "pickle" (
%SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can process the pickle files.
goto end
)
if "%1" == "json" (
%SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can process the JSON files.
goto end
)
if "%1" == "htmlhelp" (
%SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can run HTML Help Workshop with the ^
.hhp project file in %BUILDDIR%/htmlhelp.
goto end
)
if "%1" == "qthelp" (
%SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can run "qcollectiongenerator" with the ^
.qhcp project file in %BUILDDIR%/qthelp, like this:
echo.^> qcollectiongenerator %BUILDDIR%\qthelp\wptserve.qhcp
echo.To view the help file:
echo.^> assistant -collectionFile %BUILDDIR%\qthelp\wptserve.qhc
goto end
)
if "%1" == "devhelp" (
%SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
if errorlevel 1 exit /b 1
echo.
echo.Build finished.
goto end
)
if "%1" == "epub" (
%SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The epub file is in %BUILDDIR%/epub.
goto end
)
if "%1" == "latex" (
%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
if errorlevel 1 exit /b 1
echo.
echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
goto end
)
if "%1" == "text" (
%SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The text files are in %BUILDDIR%/text.
goto end
)
if "%1" == "man" (
%SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The manual pages are in %BUILDDIR%/man.
goto end
)
if "%1" == "texinfo" (
%SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
goto end
)
if "%1" == "gettext" (
%SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
goto end
)
if "%1" == "changes" (
%SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
if errorlevel 1 exit /b 1
echo.
echo.The overview file is in %BUILDDIR%/changes.
goto end
)
if "%1" == "linkcheck" (
%SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
if errorlevel 1 exit /b 1
echo.
echo.Link check complete; look for any errors in the above output ^
or in %BUILDDIR%/linkcheck/output.txt.
goto end
)
if "%1" == "doctest" (
%SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
if errorlevel 1 exit /b 1
echo.
echo.Testing of doctests in the sources finished, look at the ^
results in %BUILDDIR%/doctest/output.txt.
goto end
)
:end

View file

@@ -0,0 +1,8 @@
Pipes
======
:mod:`Interface <wptserve.pipes>`
---------------------------------
.. automodule:: wptserve.pipes
:members:

View file

@@ -0,0 +1,10 @@
Request
=======
Request object.
:mod:`Interface <wptserve.request>`
-----------------------------------
.. automodule:: wptserve.request
:members:

View file

@@ -0,0 +1,41 @@
Response
========
Response object. This object is used to control the response that will
be sent to the HTTP client. A handler function will take the response
object and fill in various parts of the response. For example, a plain
text response with the body 'Some example content' could be produced as::
def handler(request, response):
response.headers.set("Content-Type", "text/plain")
response.content = "Some example content"
The response object also gives access to a ResponseWriter, which
allows direct access to the response socket. For example, one could
write a similar response but with more explicit control as follows::
import time
def handler(request, response):
response.add_required_headers = False # Don't implicitly add HTTP headers
response.writer.write_status(200)
response.writer.write_header("Content-Type", "text/plain")
response.writer.write_header("Content-Length", len("Some example content"))
response.writer.end_headers()
response.writer.write("Some ")
time.sleep(1)
response.writer.write("example content")
Note that when writing the response directly like this it is always
necessary to either set the Content-Length header or set
`response.close_connection = True`. Without one of these, the client
will not be able to determine where the response body ends and will
continue to load indefinitely.
.. _response.Interface:
:mod:`Interface <wptserve.response>`
------------------------------------
.. automodule:: wptserve.response
:members:

@@ -0,0 +1,78 @@
Router
======
The router is used to match incoming requests to request handler
functions. Typically users don't interact with the router directly,
but instead send a list of routes to register when starting the
server. However, it is also possible to add routes after starting the
server by calling the `register` method on the server's `router`
property.
Routes are represented by a three item tuple::
(methods, path_match, handler)
`methods` is either a string or a list of strings indicating the HTTP
methods to match. In cases where all methods should match there is a
special sentinel value `any_method` provided as a property of the
`router` module that can be used.
`path_match` is an expression that will be evaluated against the
request path to decide if the handler should match. These expressions
follow a custom syntax intended to make matching URLs straightforward
and, in particular, to be easier to use than raw regexp for URL
matching. There are three possible components of a match expression:
* Literals. These match the corresponding character exactly. The
special characters \*, \{ and \} must be escaped by prefixing them
with a \\.
* Match groups. These match any character other than / and save the
result as a named group. They are delimited by curly braces; for
example::
{abc}
would create a match group with the name `abc`.
* Stars. These are denoted with a `*` and match any character
including /. There can be at most one star
per pattern and it must follow any match groups.
Path expressions always match the entire request path and a leading /
in the expression is implied even if it is not explicitly
provided. This means that `/foo` and `foo` are equivalent.
For example, the following pattern matches all requests for resources with the
extension `.py`::
*.py
The following expression matches anything directly under `/resources`
with a `.html` extension, and places the "filename" in the `name`
group::
/resources/{name}.html
The groups, including anything that matches a `*`, are available in the
request object through the `route_match` property. This is a
dictionary mapping the group names (and `*`, if present) to the
matching part of the path. For example, given a route::
/api/{sub_api}/*
and the request path `/api/test/html/test.html`, `route_match` would
be::
{"sub_api": "test", "*": "html/test.html"}
`handler` is a function taking a request and a response object that is
responsible for constructing the response to the HTTP request. See
:doc:`handlers` for more details on handler functions.
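The matching rules above can be sketched with a simplified regex-based
matcher. This is illustrative only: `compile_match` and
`route_match_groups` are hypothetical helpers, not part of the
`wptserve.router` API, and they ignore the escape rules for \*, \{ and
\}::

```python
import re

def compile_match(expr):
    # A leading / is implied even if it is not explicitly provided
    if not expr.startswith("/"):
        expr = "/" + expr
    pattern = ""
    i = 0
    while i < len(expr):
        c = expr[i]
        if c == "{":
            # Match group: any characters other than /, saved by name
            end = expr.index("}", i)
            pattern += "(?P<%s>[^/]+)" % expr[i + 1:end]
            i = end + 1
        elif c == "*":
            # Star: matches any characters, including /
            pattern += "(?P<star>.*)"
            i += 1
        else:
            pattern += re.escape(c)
            i += 1
    # Path expressions always match the entire request path
    return re.compile("^" + pattern + "$")

def route_match_groups(expr, path):
    m = compile_match(expr).match(path)
    if m is None:
        return None
    groups = m.groupdict()
    if "star" in groups:
        groups["*"] = groups.pop("star")
    return groups

# /api/{sub_api}/* against /api/test/html/test.html
# -> {"sub_api": "test", "*": "html/test.html"}
```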
.. _router.Interface:
:mod:`Interface <wptserve.router>`
----------------------------------
.. automodule:: wptserve.router
:members:

@@ -0,0 +1,20 @@
Server
======
Basic server classes and router.
The following example creates a server that serves static files from
the `files` subdirectory of the current directory and causes it to
run on port 8080 until it is killed::
from wptserve import server, handlers
httpd = server.WebTestHttpd(port=8080, doc_root="./files/",
routes=[("GET", "*", handlers.file_handler)])
httpd.start(block=True)
:mod:`Interface <wptserve.server>`
----------------------------------
.. automodule:: wptserve.server
:members:

@@ -0,0 +1,31 @@
Stash
=====
Object for storing cross-request state. This is unusual in that keys
must be UUIDs, in order to prevent different clients setting the same
key, and values are write-once, read-once to minimize the chances of
state persisting indefinitely. The stash defines two operations:
`put`, to add state, and `take`, to remove state. Furthermore, the view
of the stash is path-specific; by default a request will only see the
part of the stash corresponding to its own path.
A typical example of using a stash to store state might be::
@handler
def handler(request, response):
# We assume this is a string representing a UUID
key = request.GET.first("id")
if request.method == "POST":
request.server.stash.put(key, "Some sample value")
return "Added value to stash"
else:
value = request.server.stash.take(key)
assert request.server.stash.take(key) is None
return key
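The write-once, read-once behaviour can be sketched with a minimal
in-memory store. `MiniStash` is a hypothetical stand-in for
illustration only; the real `wptserve.stash` is additionally
path-scoped and safe to use across processes::

```python
import uuid

class MiniStash:
    """Illustrative write-once, read-once store keyed by UUID strings."""

    def __init__(self):
        self._data = {}

    def put(self, key, value):
        uuid.UUID(key)  # raises ValueError if the key is not a valid UUID
        if key in self._data:
            raise ValueError("a value was already stored under %s" % key)
        self._data[key] = value

    def take(self, key):
        # Taking removes the value, so state cannot persist indefinitely
        return self._data.pop(key, None)
```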
:mod:`Interface <wptserve.stash>`
---------------------------------
.. automodule:: wptserve.stash
:members:

@@ -0,0 +1,23 @@
from setuptools import setup
PACKAGE_VERSION = '3.0'
deps = ["six>=1.13.0", "h2>=3.0.1"]
setup(name='wptserve',
version=PACKAGE_VERSION,
description="Python web server intended for use in web browser testing",
long_description=open("README.md").read(),
# Get strings from http://pypi.python.org/pypi?%3Aaction=list_classifiers
classifiers=["Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: BSD License",
"Topic :: Internet :: WWW/HTTP :: HTTP Servers"],
keywords='',
author='James Graham',
author_email='james@hoppipolla.co.uk',
url='http://wptserve.readthedocs.org/',
license='BSD',
packages=['wptserve', 'wptserve.sslutils'],
include_package_data=True,
zip_safe=False,
install_requires=deps
)

@@ -0,0 +1,160 @@
from __future__ import print_function
import base64
import logging
import os
import pytest
import unittest
from six.moves.urllib.parse import urlencode, urlunsplit
from six.moves.urllib.request import Request as BaseRequest
from six.moves.urllib.request import urlopen
from six import binary_type, iteritems, PY3
from hyper import HTTP20Connection, tls
import ssl
from localpaths import repo_root
wptserve = pytest.importorskip("wptserve")
logging.basicConfig()
wptserve.logger.set_logger(logging.getLogger())
here = os.path.dirname(__file__)
doc_root = os.path.join(here, "docroot")
class Request(BaseRequest):
def __init__(self, *args, **kwargs):
BaseRequest.__init__(self, *args, **kwargs)
self.method = "GET"
def get_method(self):
return self.method
def add_data(self, data):
if hasattr(data, "items"):
data = urlencode(data).encode("ascii")
assert isinstance(data, binary_type)
if hasattr(BaseRequest, "add_data"):
BaseRequest.add_data(self, data)
else:
self.data = data
self.add_header("Content-Length", str(len(data)))
class TestUsingServer(unittest.TestCase):
def setUp(self):
self.server = wptserve.server.WebTestHttpd(host="localhost",
port=0,
use_ssl=False,
certificate=None,
doc_root=doc_root)
self.server.start(False)
def tearDown(self):
self.server.stop()
def abs_url(self, path, query=None):
return urlunsplit(("http", "%s:%i" % (self.server.host, self.server.port), path, query, None))
def request(self, path, query=None, method="GET", headers=None, body=None, auth=None):
req = Request(self.abs_url(path, query))
req.method = method
if headers is None:
headers = {}
for name, value in iteritems(headers):
req.add_header(name, value)
if body is not None:
req.add_data(body)
if auth is not None:
req.add_header("Authorization", b"Basic %s" % base64.b64encode(b"%s:%s" % auth))
return urlopen(req)
def assert_multiple_headers(self, resp, name, values):
if PY3:
assert resp.info().get_all(name) == values
else:
assert resp.info()[name] == ", ".join(values)
@pytest.mark.skipif(not wptserve.utils.http2_compatible(), reason="h2 server only works in python 2.7.10+ and Python 3.6+")
class TestUsingH2Server:
def setup_method(self, test_method):
self.server = wptserve.server.WebTestHttpd(host="localhost",
port=0,
use_ssl=True,
doc_root=doc_root,
key_file=os.path.join(repo_root, "tools", "certs", "web-platform.test.key"),
certificate=os.path.join(repo_root, "tools", "certs", "web-platform.test.pem"),
handler_cls=wptserve.server.Http2WebTestRequestHandler,
http2=True)
self.server.start(False)
context = tls.init_context()
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
context.set_alpn_protocols(['h2'])
self.conn = HTTP20Connection('%s:%i' % (self.server.host, self.server.port), enable_push=True, secure=True, ssl_context=context)
self.conn.connect()
def teardown_method(self, test_method):
self.server.stop()
class TestWrapperHandlerUsingServer(TestUsingServer):
'''For a wrapper handler, a dummy .js testing file is required to
render the HTML file. This class extends TestUsingServer and does some
extra work: it generates the dummy .js file in setUp and removes it in
tearDown.'''
dummy_files = {}
def gen_file(self, filename, empty=True, content=b''):
self.remove_file(filename)
with open(filename, 'wb') as fp:
if not empty:
fp.write(content)
def remove_file(self, filename):
if os.path.exists(filename):
os.remove(filename)
def setUp(self):
super(TestWrapperHandlerUsingServer, self).setUp()
for filename, content in self.dummy_files.items():
filepath = os.path.join(doc_root, filename)
if content == '':
self.gen_file(filepath)
else:
self.gen_file(filepath, False, content)
def run_wrapper_test(self, req_file, content_type, wrapper_handler,
headers=None):
route = ('GET', req_file, wrapper_handler())
self.server.router.register(*route)
resp = self.request(route[1])
self.assertEqual(200, resp.getcode())
self.assertEqual(content_type, resp.info()['Content-Type'])
for key, val in headers or []:
self.assertEqual(val, resp.info()[key])
with open(os.path.join(doc_root, req_file), 'rb') as fp:
self.assertEqual(fp.read(), resp.read())
def tearDown(self):
super(TestWrapperHandlerUsingServer, self).tearDown()
for filename, _ in self.dummy_files.items():
filepath = os.path.join(doc_root, filename)
self.remove_file(filepath)

@@ -0,0 +1,9 @@
self.GLOBAL = {
isWindow: function() { return false; },
isWorker: function() { return true; },
};
importScripts("/resources/testharness.js");
importScripts("/bar.any.js");
done();

@@ -0,0 +1 @@
This is a test document

@@ -0,0 +1,14 @@
<!doctype html>
<meta charset=utf-8>
<script>
self.GLOBAL = {
isWindow: function() { return true; },
isWorker: function() { return false; },
};
</script>
<script src="/resources/testharness.js"></script>
<script src="/resources/testharnessreport.js"></script>
<div id=log></div>
<script src="/foo.any.js"></script>

@@ -0,0 +1,15 @@
<!doctype html>
<meta charset=utf-8>
<script src="/resources/testharness.js"></script>
<script src="/resources/testharnessreport.js"></script>
<div id=log></div>
<script>
(async function() {
const scope = 'does/not/exist';
let reg = await navigator.serviceWorker.getRegistration(scope);
if (reg) await reg.unregister();
reg = await navigator.serviceWorker.register("/foo.any.worker.js", {scope});
fetch_tests_from_worker(reg.installing);
})();
</script>

@@ -0,0 +1,9 @@
<!doctype html>
<meta charset=utf-8>
<script src="/resources/testharness.js"></script>
<script src="/resources/testharnessreport.js"></script>
<div id=log></div>
<script>
fetch_tests_from_worker(new SharedWorker("/foo.any.worker.js"));
</script>

@@ -0,0 +1,9 @@
<!doctype html>
<meta charset=utf-8>
<script src="/resources/testharness.js"></script>
<script src="/resources/testharnessreport.js"></script>
<div id=log></div>
<script>
fetch_tests_from_worker(new Worker("/foo.any.worker.js"));
</script>

@@ -0,0 +1,8 @@
<!doctype html>
<meta charset=utf-8>
<script src="/resources/testharness.js"></script>
<script src="/resources/testharnessreport.js"></script>
<div id=log></div>
<script src="/foo.window.js"></script>

@@ -0,0 +1,9 @@
<!doctype html>
<meta charset=utf-8>
<script src="/resources/testharness.js"></script>
<script src="/resources/testharnessreport.js"></script>
<div id=log></div>
<script>
fetch_tests_from_worker(new Worker("/foo.worker.js"));
</script>

@@ -0,0 +1,3 @@
# Intentional syntax error in this file
def main(request, response:
return "FAIL"

@@ -0,0 +1,3 @@
# Oops...
def mian(request, response):
return "FAIL"

@@ -0,0 +1 @@
{{host}} {{domains[]}} {{ports[http][0]}}

@@ -0,0 +1 @@
{{host}} {{domains[]}} {{ports[http][0]}}

@@ -0,0 +1,6 @@
md5: {{file_hash(md5, sub_file_hash_subject.txt)}}
sha1: {{file_hash(sha1, sub_file_hash_subject.txt)}}
sha224: {{file_hash(sha224, sub_file_hash_subject.txt)}}
sha256: {{file_hash(sha256, sub_file_hash_subject.txt)}}
sha384: {{file_hash(sha384, sub_file_hash_subject.txt)}}
sha512: {{file_hash(sha512, sub_file_hash_subject.txt)}}

@@ -0,0 +1,2 @@
This file is used to verify expected behavior of the `file_hash` "sub"
function.

@@ -0,0 +1 @@
{{file_hash(sha007, sub_file_hash_subject.txt)}}

@@ -0,0 +1,2 @@
{{header_or_default(X-Present, present-default)}}
{{header_or_default(X-Absent, absent-default)}}

@@ -0,0 +1 @@
{{headers[X-Test]}}

@@ -0,0 +1 @@
{{headers[X-Test]}}

@@ -0,0 +1,8 @@
host: {{location[host]}}
hostname: {{location[hostname]}}
path: {{location[path]}}
pathname: {{location[pathname]}}
port: {{location[port]}}
query: {{location[query]}}
scheme: {{location[scheme]}}
server: {{location[server]}}

@@ -0,0 +1 @@
{{GET[plus pct-20 pct-3D=]}}

@@ -0,0 +1 @@
{{GET[plus pct-20 pct-3D=]}}

@@ -0,0 +1 @@
Before {{url_base}} After

@@ -0,0 +1 @@
Before {{uuid()}} After

@@ -0,0 +1 @@
{{$first:host}} {{$second:ports[http][0]}} A {{$second}} B {{$first}} C

@@ -0,0 +1,2 @@
def module_function():
return [("Content-Type", "text/plain")], "PASS"

@@ -0,0 +1 @@
I am here to ensure that my containing directory exists.

@@ -0,0 +1,5 @@
from subdir import example_module
def main(request, response):
return example_module.module_function()

@@ -0,0 +1,3 @@
{{fs_path(sub_path.sub.txt)}}
{{fs_path(../sub_path.sub.txt)}}
{{fs_path(/sub_path.sub.txt)}}

@@ -0,0 +1,5 @@
HTTP/1.1 202 Giraffe
X-TEST: PASS
Content-Length: 7
Content

@@ -0,0 +1,2 @@
def handle_data(frame, request, response):
response.content = frame.data[::-1]

@@ -0,0 +1,3 @@
def handle_headers(frame, request, response):
response.status = 203
response.headers.update([('test', 'passed')])

@@ -0,0 +1,6 @@
def handle_headers(frame, request, response):
response.status = 203
response.headers.update([('test', 'passed')])
def handle_data(frame, request, response):
response.content = frame.data[::-1]

@@ -0,0 +1,3 @@
def main(request, response):
response.headers.set("Content-Type", "text/plain")
return "PASS"

@@ -0,0 +1,2 @@
def main(request, response):
return [("Content-Type", "text/html"), ("X-Test", "PASS")], "PASS"

@@ -0,0 +1,2 @@
def main(request, response):
return (202, "Giraffe"), [("Content-Type", "text/html"), ("X-Test", "PASS")], "PASS"

@@ -0,0 +1 @@
Test document with custom headers

@@ -0,0 +1,6 @@
Custom-Header: PASS
Another-Header: {{$id:uuid()}}
Same-Value-Header: {{$id}}
Double-Header: PA
Double-Header: SS
Content-Type: text/html

@@ -0,0 +1,66 @@
import unittest
import pytest
wptserve = pytest.importorskip("wptserve")
from .base import TestUsingServer
class TestResponseSetCookie(TestUsingServer):
def test_name_value(self):
@wptserve.handlers.handler
def handler(request, response):
response.set_cookie(b"name", b"value")
return "Test"
route = ("GET", "/test/name_value", handler)
self.server.router.register(*route)
resp = self.request(route[1])
self.assertEqual(resp.info()["Set-Cookie"], "name=value; Path=/")
def test_unset(self):
@wptserve.handlers.handler
def handler(request, response):
response.set_cookie(b"name", b"value")
response.unset_cookie(b"name")
return "Test"
route = ("GET", "/test/unset", handler)
self.server.router.register(*route)
resp = self.request(route[1])
self.assertTrue("Set-Cookie" not in resp.info())
def test_delete(self):
@wptserve.handlers.handler
def handler(request, response):
response.delete_cookie(b"name")
return "Test"
route = ("GET", "/test/delete", handler)
self.server.router.register(*route)
resp = self.request(route[1])
parts = dict(item.split("=") for
item in resp.info()["Set-Cookie"].split("; ") if item)
self.assertEqual(parts["name"], "")
self.assertEqual(parts["Path"], "/")
# TODO: Should also check that expires is in the past
class TestRequestCookies(TestUsingServer):
def test_set_cookie(self):
@wptserve.handlers.handler
def handler(request, response):
return request.cookies[b"name"].value
route = ("GET", "/test/set_cookie", handler)
self.server.router.register(*route)
resp = self.request(route[1], headers={"Cookie": "name=value"})
self.assertEqual(resp.read(), b"value")
if __name__ == '__main__':
unittest.main()

@@ -0,0 +1,448 @@
import json
import os
import sys
import unittest
import uuid
import pytest
from six.moves.urllib.error import HTTPError
wptserve = pytest.importorskip("wptserve")
from .base import TestUsingServer, TestUsingH2Server, doc_root
from .base import TestWrapperHandlerUsingServer
from serve import serve
class TestFileHandler(TestUsingServer):
def test_GET(self):
resp = self.request("/document.txt")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/plain", resp.info()["Content-Type"])
self.assertEqual(open(os.path.join(doc_root, "document.txt"), 'rb').read(), resp.read())
def test_headers(self):
resp = self.request("/with_headers.txt")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/html", resp.info()["Content-Type"])
self.assertEqual("PASS", resp.info()["Custom-Header"])
# This will fail if it isn't a valid uuid
uuid.UUID(resp.info()["Another-Header"])
self.assertEqual(resp.info()["Same-Value-Header"], resp.info()["Another-Header"])
self.assert_multiple_headers(resp, "Double-Header", ["PA", "SS"])
def test_range(self):
resp = self.request("/document.txt", headers={"Range":"bytes=10-19"})
self.assertEqual(206, resp.getcode())
data = resp.read()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(10, len(data))
self.assertEqual("bytes 10-19/%i" % len(expected), resp.info()['Content-Range'])
self.assertEqual("10", resp.info()['Content-Length'])
self.assertEqual(expected[10:20], data)
def test_range_no_end(self):
resp = self.request("/document.txt", headers={"Range":"bytes=10-"})
self.assertEqual(206, resp.getcode())
data = resp.read()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(len(expected) - 10, len(data))
self.assertEqual("bytes 10-%i/%i" % (len(expected) - 1, len(expected)), resp.info()['Content-Range'])
self.assertEqual(expected[10:], data)
def test_range_no_start(self):
resp = self.request("/document.txt", headers={"Range":"bytes=-10"})
self.assertEqual(206, resp.getcode())
data = resp.read()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(10, len(data))
self.assertEqual("bytes %i-%i/%i" % (len(expected) - 10, len(expected) - 1, len(expected)),
resp.info()['Content-Range'])
self.assertEqual(expected[-10:], data)
def test_multiple_ranges(self):
resp = self.request("/document.txt", headers={"Range":"bytes=1-2,5-7,6-10"})
self.assertEqual(206, resp.getcode())
data = resp.read()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertTrue(resp.info()["Content-Type"].startswith("multipart/byteranges; boundary="))
boundary = resp.info()["Content-Type"].split("boundary=")[1]
parts = data.split(b"--" + boundary.encode("ascii"))
self.assertEqual(b"\r\n", parts[0])
self.assertEqual(b"--", parts[-1])
expected_parts = [(b"1-2", expected[1:3]), (b"5-10", expected[5:11])]
for expected_part, part in zip(expected_parts, parts[1:-1]):
header_string, body = part.split(b"\r\n\r\n")
headers = dict(item.split(b": ", 1) for item in header_string.split(b"\r\n") if item.strip())
self.assertEqual(headers[b"Content-Type"], b"text/plain")
self.assertEqual(headers[b"Content-Range"], b"bytes %s/%i" % (expected_part[0], len(expected)))
self.assertEqual(expected_part[1] + b"\r\n", body)
def test_range_invalid(self):
with self.assertRaises(HTTPError) as cm:
self.request("/document.txt", headers={"Range":"bytes=11-10"})
self.assertEqual(cm.exception.code, 416)
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
with self.assertRaises(HTTPError) as cm:
self.request("/document.txt", headers={"Range":"bytes=%i-%i" % (len(expected), len(expected) + 10)})
self.assertEqual(cm.exception.code, 416)
def test_sub_config(self):
resp = self.request("/sub.sub.txt")
expected = b"localhost localhost %i" % self.server.port
assert resp.read().rstrip() == expected
def test_sub_headers(self):
resp = self.request("/sub_headers.sub.txt", headers={"X-Test": "PASS"})
expected = b"PASS"
assert resp.read().rstrip() == expected
def test_sub_params(self):
resp = self.request("/sub_params.txt", query="plus+pct-20%20pct-3D%3D=PLUS+PCT-20%20PCT-3D%3D&pipe=sub")
expected = b"PLUS PCT-20 PCT-3D="
assert resp.read().rstrip() == expected
class TestFunctionHandler(TestUsingServer):
def test_string_rv(self):
@wptserve.handlers.handler
def handler(request, response):
return "test data"
route = ("GET", "/test/test_string_rv", handler)
self.server.router.register(*route)
resp = self.request(route[1])
self.assertEqual(200, resp.getcode())
self.assertEqual("9", resp.info()["Content-Length"])
self.assertEqual(b"test data", resp.read())
def test_tuple_1_rv(self):
@wptserve.handlers.handler
def handler(request, response):
return ()
route = ("GET", "/test/test_tuple_1_rv", handler)
self.server.router.register(*route)
with pytest.raises(HTTPError) as cm:
self.request(route[1])
assert cm.value.code == 500
def test_tuple_2_rv(self):
@wptserve.handlers.handler
def handler(request, response):
return [("Content-Length", 4), ("test-header", "test-value")], "test data"
route = ("GET", "/test/test_tuple_2_rv", handler)
self.server.router.register(*route)
resp = self.request(route[1])
self.assertEqual(200, resp.getcode())
self.assertEqual("4", resp.info()["Content-Length"])
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual(b"test", resp.read())
def test_tuple_3_rv(self):
@wptserve.handlers.handler
def handler(request, response):
return 202, [("test-header", "test-value")], "test data"
route = ("GET", "/test/test_tuple_3_rv", handler)
self.server.router.register(*route)
resp = self.request(route[1])
self.assertEqual(202, resp.getcode())
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual(b"test data", resp.read())
def test_tuple_3_rv_1(self):
@wptserve.handlers.handler
def handler(request, response):
return (202, "Some Status"), [("test-header", "test-value")], "test data"
route = ("GET", "/test/test_tuple_3_rv_1", handler)
self.server.router.register(*route)
resp = self.request(route[1])
self.assertEqual(202, resp.getcode())
self.assertEqual("Some Status", resp.msg)
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual(b"test data", resp.read())
def test_tuple_4_rv(self):
@wptserve.handlers.handler
def handler(request, response):
return 202, [("test-header", "test-value")], "test data", "garbage"
route = ("GET", "/test/test_tuple_1_rv", handler)
self.server.router.register(*route)
with pytest.raises(HTTPError) as cm:
self.request(route[1])
assert cm.value.code == 500
def test_none_rv(self):
@wptserve.handlers.handler
def handler(request, response):
return None
route = ("GET", "/test/test_none_rv", handler)
self.server.router.register(*route)
resp = self.request(route[1])
assert resp.getcode() == 200
assert "Content-Length" not in resp.info()
assert resp.read() == b""
class TestJSONHandler(TestUsingServer):
def test_json_0(self):
@wptserve.handlers.json_handler
def handler(request, response):
return {"data": "test data"}
route = ("GET", "/test/test_json_0", handler)
self.server.router.register(*route)
resp = self.request(route[1])
self.assertEqual(200, resp.getcode())
self.assertEqual({"data": "test data"}, json.load(resp))
def test_json_tuple_2(self):
@wptserve.handlers.json_handler
def handler(request, response):
return [("Test-Header", "test-value")], {"data": "test data"}
route = ("GET", "/test/test_json_tuple_2", handler)
self.server.router.register(*route)
resp = self.request(route[1])
self.assertEqual(200, resp.getcode())
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual({"data": "test data"}, json.load(resp))
def test_json_tuple_3(self):
@wptserve.handlers.json_handler
def handler(request, response):
return (202, "Giraffe"), [("Test-Header", "test-value")], {"data": "test data"}
route = ("GET", "/test/test_json_tuple_2", handler)
self.server.router.register(*route)
resp = self.request(route[1])
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
self.assertEqual("test-value", resp.info()["test-header"])
self.assertEqual({"data": "test data"}, json.load(resp))
class TestPythonHandler(TestUsingServer):
def test_string(self):
resp = self.request("/test_string.py")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/plain", resp.info()["Content-Type"])
self.assertEqual(b"PASS", resp.read())
def test_tuple_2(self):
resp = self.request("/test_tuple_2.py")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/html", resp.info()["Content-Type"])
self.assertEqual("PASS", resp.info()["X-Test"])
self.assertEqual(b"PASS", resp.read())
def test_tuple_3(self):
resp = self.request("/test_tuple_3.py")
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
self.assertEqual("text/html", resp.info()["Content-Type"])
self.assertEqual("PASS", resp.info()["X-Test"])
self.assertEqual(b"PASS", resp.read())
def test_import(self):
dir_name = os.path.join(doc_root, "subdir")
assert dir_name not in sys.path
assert "test_module" not in sys.modules
resp = self.request("/subdir/import_handler.py")
assert dir_name not in sys.path
assert "test_module" not in sys.modules
self.assertEqual(200, resp.getcode())
self.assertEqual("text/plain", resp.info()["Content-Type"])
self.assertEqual(b"PASS", resp.read())
def test_no_main(self):
with pytest.raises(HTTPError) as cm:
self.request("/no_main.py")
assert cm.value.code == 500
def test_invalid(self):
with pytest.raises(HTTPError) as cm:
self.request("/invalid.py")
assert cm.value.code == 500
def test_missing(self):
with pytest.raises(HTTPError) as cm:
self.request("/missing.py")
assert cm.value.code == 404
class TestDirectoryHandler(TestUsingServer):
def test_directory(self):
resp = self.request("/")
self.assertEqual(200, resp.getcode())
self.assertEqual("text/html", resp.info()["Content-Type"])
# TODO: Add a check that the response is actually sane
def test_subdirectory_trailing_slash(self):
resp = self.request("/subdir/")
assert resp.getcode() == 200
assert resp.info()["Content-Type"] == "text/html"
def test_subdirectory_no_trailing_slash(self):
# This seems to resolve the 301 transparently, so test for 200
resp = self.request("/subdir")
assert resp.getcode() == 200
assert resp.info()["Content-Type"] == "text/html"
class TestAsIsHandler(TestUsingServer):
def test_as_is(self):
resp = self.request("/test.asis")
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
self.assertEqual("PASS", resp.info()["X-Test"])
self.assertEqual(b"Content", resp.read())
# TODO: Add a check that the response is actually sane
class TestH2Handler(TestUsingH2Server):
def test_handle_headers(self):
self.conn.request("GET", '/test_h2_headers.py')
resp = self.conn.get_response()
assert resp.status == 203
assert resp.headers['test'][0] == b'passed'
assert resp.read() == b''
def test_only_main(self):
self.conn.request("GET", '/test_tuple_3.py')
resp = self.conn.get_response()
assert resp.status == 202
assert resp.headers['Content-Type'][0] == b'text/html'
assert resp.headers['X-Test'][0] == b'PASS'
assert resp.read() == b'PASS'
def test_handle_data(self):
self.conn.request("POST", '/test_h2_data.py', body="hello world!")
resp = self.conn.get_response()
assert resp.status == 200
assert resp.read() == b'!dlrow olleh'
def test_handle_headers_data(self):
self.conn.request("POST", '/test_h2_headers_data.py', body="hello world!")
resp = self.conn.get_response()
assert resp.status == 203
assert resp.headers['test'][0] == b'passed'
assert resp.read() == b'!dlrow olleh'
def test_no_main_or_handlers(self):
self.conn.request("GET", '/no_main.py')
resp = self.conn.get_response()
assert resp.status == 500
assert "No main function or handlers in script " in json.loads(resp.read())["error"]["message"]
def test_not_found(self):
self.conn.request("GET", '/no_exist.py')
resp = self.conn.get_response()
assert resp.status == 404
def test_requesting_multiple_resources(self):
# 1st .py resource
self.conn.request("GET", '/test_h2_headers.py')
resp = self.conn.get_response()
assert resp.status == 203
assert resp.headers['test'][0] == b'passed'
assert resp.read() == b''
# 2nd .py resource
self.conn.request("GET", '/test_tuple_3.py')
resp = self.conn.get_response()
assert resp.status == 202
assert resp.headers['Content-Type'][0] == b'text/html'
assert resp.headers['X-Test'][0] == b'PASS'
assert resp.read() == b'PASS'
# 3rd .py resource
self.conn.request("GET", '/test_h2_headers.py')
resp = self.conn.get_response()
assert resp.status == 203
assert resp.headers['test'][0] == b'passed'
assert resp.read() == b''
class TestWorkersHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.worker.js': b'',
'foo.any.js': b''}
def test_any_worker_html(self):
self.run_wrapper_test('foo.any.worker.html',
'text/html', serve.WorkersHandler)
def test_worker_html(self):
self.run_wrapper_test('foo.worker.html',
'text/html', serve.WorkersHandler)
class TestWindowHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.window.js': b''}
def test_window_html(self):
self.run_wrapper_test('foo.window.html',
'text/html', serve.WindowHandler)
class TestAnyHtmlHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.any.js': b'',
'foo.any.js.headers': b'X-Foo: 1',
'__dir__.headers': b'X-Bar: 2'}
def test_any_html(self):
self.run_wrapper_test('foo.any.html',
'text/html',
serve.AnyHtmlHandler,
headers=[('X-Foo', '1'), ('X-Bar', '2')])
class TestSharedWorkersHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.any.js': b'// META: global=sharedworker\n'}
def test_any_sharedworkers_html(self):
self.run_wrapper_test('foo.any.sharedworker.html',
'text/html', serve.SharedWorkersHandler)
class TestServiceWorkersHandler(TestWrapperHandlerUsingServer):
dummy_files = {'foo.any.js': b'// META: global=serviceworker\n'}
def test_serviceworker_html(self):
self.run_wrapper_test('foo.any.serviceworker.html',
'text/html', serve.ServiceWorkersHandler)
class TestAnyWorkerHandler(TestWrapperHandlerUsingServer):
dummy_files = {'bar.any.js': b''}
def test_any_work_js(self):
self.run_wrapper_test('bar.any.worker.js', 'text/javascript',
serve.AnyWorkerHandler)
if __name__ == '__main__':
unittest.main()

@@ -0,0 +1,155 @@
import sys
from io import BytesIO
import pytest
from six import PY2
from wptserve.request import InputFile
bstr = b'This is a test document\nWith new lines\nSeveral in fact...'
rfile = None
test_file = None  # A file-like object to check InputFile behaviour against
input_file = InputFile(None, 0)
def setup_function(function):
global rfile, input_file, test_file
rfile = BytesIO(bstr)
test_file = BytesIO(bstr)
input_file = InputFile(rfile, len(bstr))
def teardown_function(function):
rfile.close()
test_file.close()
def test_seek():
input_file.seek(2)
test_file.seek(2)
assert input_file.read(1) == test_file.read(1)
input_file.seek(4)
test_file.seek(4)
assert input_file.read(1) == test_file.read(1)
def test_seek_backwards():
input_file.seek(2)
test_file.seek(2)
assert input_file.tell() == test_file.tell()
assert input_file.read(1) == test_file.read(1)
assert input_file.tell() == test_file.tell()
input_file.seek(0)
test_file.seek(0)
assert input_file.read(1) == test_file.read(1)
def test_seek_negative_offset():
with pytest.raises(ValueError):
input_file.seek(-1)
def test_seek_file_bigger_than_buffer():
old_max_buf = InputFile.max_buffer_size
InputFile.max_buffer_size = 10
try:
input_file = InputFile(rfile, len(bstr))
input_file.seek(2)
test_file.seek(2)
assert input_file.read(1) == test_file.read(1)
input_file.seek(4)
test_file.seek(4)
assert input_file.read(1) == test_file.read(1)
finally:
InputFile.max_buffer_size = old_max_buf
def test_read():
assert input_file.read() == test_file.read()
def test_read_file_bigger_than_buffer():
old_max_buf = InputFile.max_buffer_size
InputFile.max_buffer_size = 10
try:
input_file = InputFile(rfile, len(bstr))
assert input_file.read() == test_file.read()
finally:
InputFile.max_buffer_size = old_max_buf
def test_readline():
assert input_file.readline() == test_file.readline()
assert input_file.readline() == test_file.readline()
input_file.seek(0)
test_file.seek(0)
assert input_file.readline() == test_file.readline()
def test_readline_max_byte():
line = test_file.readline()
assert input_file.readline(max_bytes=len(line)//2) == line[:len(line)//2]
assert input_file.readline(max_bytes=len(line)) == line[len(line)//2:]
def test_readline_max_byte_longer_than_file():
assert input_file.readline(max_bytes=1000) == test_file.readline()
assert input_file.readline(max_bytes=1000) == test_file.readline()
def test_readline_file_bigger_than_buffer():
old_max_buf = InputFile.max_buffer_size
InputFile.max_buffer_size = 10
try:
input_file = InputFile(rfile, len(bstr))
assert input_file.readline() == test_file.readline()
assert input_file.readline() == test_file.readline()
finally:
InputFile.max_buffer_size = old_max_buf
def test_readlines():
assert input_file.readlines() == test_file.readlines()
@pytest.mark.xfail(sys.platform == "win32" and PY2,
reason="https://github.com/web-platform-tests/wpt/issues/12949")
def test_readlines_file_bigger_than_buffer():
old_max_buf = InputFile.max_buffer_size
InputFile.max_buffer_size = 10
try:
input_file = InputFile(rfile, len(bstr))
assert input_file.readlines() == test_file.readlines()
finally:
InputFile.max_buffer_size = old_max_buf
def test_iter():
for a, b in zip(input_file, test_file):
assert a == b
@pytest.mark.xfail(sys.platform == "win32" and PY2,
reason="https://github.com/web-platform-tests/wpt/issues/12949")
def test_iter_file_bigger_than_buffer():
old_max_buf = InputFile.max_buffer_size
InputFile.max_buffer_size = 10
try:
input_file = InputFile(rfile, len(bstr))
for a, b in zip(input_file, test_file):
assert a == b
finally:
InputFile.max_buffer_size = old_max_buf
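Every buffer-boundary test above repeats the same save/override/`finally`-restore dance around `InputFile.max_buffer_size`. That pattern can be factored into a generic context manager — an illustrative sketch, not part of wptserve:

```python
from contextlib import contextmanager


@contextmanager
def patched_attr(obj, name, value):
    """Temporarily override an attribute, restoring it even if the body raises."""
    old = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, old)

# Usage mirroring the tests above:
# with patched_attr(InputFile, "max_buffer_size", 10):
#     ...exercise the disk-backed code path...
```

This is essentially what pytest's built-in `monkeypatch.setattr` fixture provides out of the box.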

Просмотреть файл

@ -0,0 +1,230 @@
import os
import unittest
import time
import json
from six import assertRegex
from six.moves import urllib
import pytest
wptserve = pytest.importorskip("wptserve")
from .base import TestUsingServer, doc_root
class TestStatus(TestUsingServer):
def test_status(self):
resp = self.request("/document.txt", query="pipe=status(202)")
self.assertEqual(resp.getcode(), 202)
class TestHeader(TestUsingServer):
def test_not_set(self):
resp = self.request("/document.txt", query="pipe=header(X-TEST,PASS)")
self.assertEqual(resp.info()["X-TEST"], "PASS")
def test_set(self):
resp = self.request("/document.txt", query="pipe=header(Content-Type,text/html)")
self.assertEqual(resp.info()["Content-Type"], "text/html")
def test_multiple(self):
resp = self.request("/document.txt", query="pipe=header(X-Test,PASS)|header(Content-Type,text/html)")
self.assertEqual(resp.info()["X-TEST"], "PASS")
self.assertEqual(resp.info()["Content-Type"], "text/html")
def test_multiple_same(self):
resp = self.request("/document.txt", query="pipe=header(Content-Type,FAIL)|header(Content-Type,text/html)")
self.assertEqual(resp.info()["Content-Type"], "text/html")
def test_multiple_append(self):
resp = self.request("/document.txt", query="pipe=header(X-Test,1)|header(X-Test,2,True)")
self.assert_multiple_headers(resp, "X-Test", ["1", "2"])
def test_semicolon(self):
resp = self.request("/document.txt", query="pipe=header(Refresh,3;url=http://example.com)")
self.assertEqual(resp.info()["Refresh"], "3;url=http://example.com")
def test_escape_comma(self):
resp = self.request("/document.txt", query=r"pipe=header(Expires,Thu\,%2014%20Aug%201986%2018:00:00%20GMT)")
self.assertEqual(resp.info()["Expires"], "Thu, 14 Aug 1986 18:00:00 GMT")
def test_escape_parenthesis(self):
resp = self.request("/document.txt", query=r"pipe=header(User-Agent,Mozilla/5.0%20(X11;%20Linux%20x86_64;%20rv:12.0\)")
self.assertEqual(resp.info()["User-Agent"], "Mozilla/5.0 (X11; Linux x86_64; rv:12.0)")
class TestSlice(TestUsingServer):
def test_both_bounds(self):
resp = self.request("/document.txt", query="pipe=slice(1,10)")
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(resp.read(), expected[1:10])
def test_no_upper(self):
resp = self.request("/document.txt", query="pipe=slice(1)")
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(resp.read(), expected[1:])
def test_no_lower(self):
resp = self.request("/document.txt", query="pipe=slice(null,10)")
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(resp.read(), expected[:10])
class TestSub(TestUsingServer):
def test_sub_config(self):
resp = self.request("/sub.txt", query="pipe=sub")
expected = b"localhost localhost %i" % self.server.port
self.assertEqual(resp.read().rstrip(), expected)
def test_sub_file_hash(self):
resp = self.request("/sub_file_hash.sub.txt")
expected = b"""
md5: JmI1W8fMHfSfCarYOSxJcw==
sha1: nqpWqEw4IW8NjD6R375gtrQvtTo=
sha224: RqQ6fMmta6n9TuA/vgTZK2EqmidqnrwBAmQLRQ==
sha256: G6Ljg1uPejQxqFmvFOcV/loqnjPTW5GSOePOfM/u0jw=
sha384: lkXHChh1BXHN5nT5BYhi1x67E1CyYbPKRKoF2LTm5GivuEFpVVYtvEBHtPr74N9E
sha512: r8eLGRTc7ZznZkFjeVLyo6/FyQdra9qmlYCwKKxm3kfQAswRS9+3HsYk3thLUhcFmmWhK4dXaICzJwGFonfXwg=="""
self.assertEqual(resp.read().rstrip(), expected.strip())
def test_sub_file_hash_unrecognized(self):
with self.assertRaises(urllib.error.HTTPError):
self.request("/sub_file_hash_unrecognized.sub.txt")
def test_sub_headers(self):
resp = self.request("/sub_headers.txt", query="pipe=sub", headers={"X-Test": "PASS"})
expected = b"PASS"
self.assertEqual(resp.read().rstrip(), expected)
def test_sub_location(self):
resp = self.request("/sub_location.sub.txt?query_string")
expected = """
host: localhost:{0}
hostname: localhost
path: /sub_location.sub.txt
pathname: /sub_location.sub.txt
port: {0}
query: ?query_string
scheme: http
server: http://localhost:{0}""".format(self.server.port).encode("ascii")
self.assertEqual(resp.read().rstrip(), expected.strip())
def test_sub_params(self):
resp = self.request("/sub_params.txt", query="plus+pct-20%20pct-3D%3D=PLUS+PCT-20%20PCT-3D%3D&pipe=sub")
expected = b"PLUS PCT-20 PCT-3D="
self.assertEqual(resp.read().rstrip(), expected)
def test_sub_url_base(self):
resp = self.request("/sub_url_base.sub.txt")
self.assertEqual(resp.read().rstrip(), b"Before / After")
def test_sub_url_base_via_filename_with_query(self):
resp = self.request("/sub_url_base.sub.txt?pipe=slice(5,10)")
self.assertEqual(resp.read().rstrip(), b"e / A")
def test_sub_uuid(self):
resp = self.request("/sub_uuid.sub.txt")
assertRegex(self, resp.read().rstrip(), b"Before [a-f0-9-]+ After")
def test_sub_var(self):
resp = self.request("/sub_var.sub.txt")
port = self.server.port
expected = b"localhost %d A %d B localhost C" % (port, port)
self.assertEqual(resp.read().rstrip(), expected)
def test_sub_fs_path(self):
resp = self.request("/subdir/sub_path.sub.txt")
root = os.path.abspath(doc_root)
expected = """%(root)s%(sep)ssubdir%(sep)ssub_path.sub.txt
%(root)s%(sep)ssub_path.sub.txt
%(root)s%(sep)ssub_path.sub.txt
""" % {"root": root, "sep": os.path.sep}
self.assertEqual(resp.read(), expected.encode("utf8"))
def test_sub_header_or_default(self):
resp = self.request("/sub_header_or_default.sub.txt", headers={"X-Present": "OK"})
expected = b"OK\nabsent-default"
self.assertEqual(resp.read().rstrip(), expected)
class TestTrickle(TestUsingServer):
def test_trickle(self):
        # Actually testing that the response trickles in is not that easy
t0 = time.time()
resp = self.request("/document.txt", query="pipe=trickle(1:d2:5:d1:r2)")
t1 = time.time()
expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
self.assertEqual(resp.read(), expected)
self.assertGreater(6, t1-t0)
def test_headers(self):
resp = self.request("/document.txt", query="pipe=trickle(d0.01)")
self.assertEqual(resp.info()["Cache-Control"], "no-cache, no-store, must-revalidate")
self.assertEqual(resp.info()["Pragma"], "no-cache")
self.assertEqual(resp.info()["Expires"], "0")
class TestPipesWithVariousHandlers(TestUsingServer):
def test_with_python_file_handler(self):
resp = self.request("/test_string.py", query="pipe=slice(null,2)")
self.assertEqual(resp.read(), b"PA")
def test_with_python_func_handler(self):
@wptserve.handlers.handler
def handler(request, response):
return "PASS"
route = ("GET", "/test/test_pipes_1/", handler)
self.server.router.register(*route)
resp = self.request(route[1], query="pipe=slice(null,2)")
self.assertEqual(resp.read(), b"PA")
def test_with_python_func_handler_using_response_writer(self):
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_content("PASS")
route = ("GET", "/test/test_pipes_1/", handler)
self.server.router.register(*route)
resp = self.request(route[1], query="pipe=slice(null,2)")
# slice has not been applied to the response, because response.writer was used.
self.assertEqual(resp.read(), b"PASS")
def test_header_pipe_with_python_func_using_response_writer(self):
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_content("CONTENT")
route = ("GET", "/test/test_pipes_1/", handler)
self.server.router.register(*route)
resp = self.request(route[1], query="pipe=header(X-TEST,FAIL)")
# header pipe was ignored, because response.writer was used.
self.assertFalse(resp.info().get("X-TEST"))
self.assertEqual(resp.read(), b"CONTENT")
def test_with_json_handler(self):
@wptserve.handlers.json_handler
def handler(request, response):
return json.dumps({'data': 'PASS'})
route = ("GET", "/test/test_pipes_2/", handler)
self.server.router.register(*route)
resp = self.request(route[1], query="pipe=slice(null,2)")
self.assertEqual(resp.read(), b'"{')
def test_slice_with_as_is_handler(self):
resp = self.request("/test.asis", query="pipe=slice(null,2)")
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
self.assertEqual("PASS", resp.info()["X-Test"])
# slice has not been applied to the response, because response.writer was used.
self.assertEqual(b"Content", resp.read())
def test_headers_with_as_is_handler(self):
resp = self.request("/test.asis", query="pipe=header(X-TEST,FAIL)")
self.assertEqual(202, resp.getcode())
self.assertEqual("Giraffe", resp.msg)
# header pipe was ignored.
self.assertEqual("PASS", resp.info()["X-TEST"])
self.assertEqual(b"Content", resp.read())
def test_trickle_with_as_is_handler(self):
t0 = time.time()
resp = self.request("/test.asis", query="pipe=trickle(1:d2:5:d1:r2)")
t1 = time.time()
self.assertTrue(b'Content' in resp.read())
self.assertGreater(6, t1-t0)
if __name__ == '__main__':
unittest.main()
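The `pipe=` query strings exercised above compose transformations with `|`, pass arguments in parentheses, and escape a literal `,` or `)` with a backslash. A minimal parser for that surface syntax — a sketch of the format these tests exercise, not wptserve's actual implementation — looks like:

```python
import re


def parse_pipes(value):
    # Parse a pipe query value like "header(X-Test,PASS)|slice(1,10)" into
    # (name, args) tuples.  A backslash escapes a literal ',' or ')'.
    # Simplified: assumes well-formed input with an unescaped closing ')'.
    pipes = []
    for part in value.split("|"):
        m = re.match(r"(\w+)(?:\((.*)\))?$", part)
        name, raw = m.group(1), m.group(2) or ""
        args = [a.replace("\\,", ",").replace("\\)", ")")
                for a in re.split(r"(?<!\\),", raw)] if raw else []
        pipes.append((name, args))
    return pipes
```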

Просмотреть файл

@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
import pytest
from six import PY3
wptserve = pytest.importorskip("wptserve")
from .base import TestUsingServer
from wptserve.request import InputFile
class TestInputFile(TestUsingServer):
def test_seek(self):
@wptserve.handlers.handler
def handler(request, response):
rv = []
f = request.raw_input
f.seek(5)
rv.append(f.read(2))
rv.append(b"%d" % f.tell())
f.seek(0)
rv.append(f.readline())
rv.append(b"%d" % f.tell())
rv.append(f.read(-1))
rv.append(b"%d" % f.tell())
f.seek(0)
rv.append(f.read())
f.seek(0)
rv.extend(f.readlines())
return b" ".join(rv)
route = ("POST", "/test/test_seek", handler)
self.server.router.register(*route)
resp = self.request(route[1], method="POST", body=b"12345ab\ncdef")
self.assertEqual(200, resp.getcode())
self.assertEqual([b"ab", b"7", b"12345ab\n", b"8", b"cdef", b"12",
b"12345ab\ncdef", b"12345ab\n", b"cdef"],
resp.read().split(b" "))
def test_seek_input_longer_than_buffer(self):
@wptserve.handlers.handler
def handler(request, response):
rv = []
f = request.raw_input
f.seek(5)
rv.append(f.read(2))
rv.append(b"%d" % f.tell())
f.seek(0)
rv.append(b"%d" % f.tell())
rv.append(b"%d" % f.tell())
return b" ".join(rv)
route = ("POST", "/test/test_seek", handler)
self.server.router.register(*route)
old_max_buf = InputFile.max_buffer_size
InputFile.max_buffer_size = 10
try:
resp = self.request(route[1], method="POST", body=b"1"*20)
self.assertEqual(200, resp.getcode())
self.assertEqual([b"11", b"7", b"0", b"0"],
resp.read().split(b" "))
finally:
InputFile.max_buffer_size = old_max_buf
def test_iter(self):
@wptserve.handlers.handler
def handler(request, response):
f = request.raw_input
return b" ".join(line for line in f)
route = ("POST", "/test/test_iter", handler)
self.server.router.register(*route)
resp = self.request(route[1], method="POST", body=b"12345\nabcdef\r\nzyxwv")
self.assertEqual(200, resp.getcode())
self.assertEqual([b"12345\n", b"abcdef\r\n", b"zyxwv"], resp.read().split(b" "))
def test_iter_input_longer_than_buffer(self):
@wptserve.handlers.handler
def handler(request, response):
f = request.raw_input
return b" ".join(line for line in f)
route = ("POST", "/test/test_iter", handler)
self.server.router.register(*route)
old_max_buf = InputFile.max_buffer_size
InputFile.max_buffer_size = 10
try:
resp = self.request(route[1], method="POST", body=b"12345\nabcdef\r\nzyxwv")
self.assertEqual(200, resp.getcode())
self.assertEqual([b"12345\n", b"abcdef\r\n", b"zyxwv"], resp.read().split(b" "))
finally:
InputFile.max_buffer_size = old_max_buf
class TestRequest(TestUsingServer):
def test_body(self):
@wptserve.handlers.handler
def handler(request, response):
request.raw_input.seek(5)
return request.body
route = ("POST", "/test/test_body", handler)
self.server.router.register(*route)
resp = self.request(route[1], method="POST", body=b"12345ab\ncdef")
self.assertEqual(b"12345ab\ncdef", resp.read())
def test_route_match(self):
@wptserve.handlers.handler
def handler(request, response):
return request.route_match["match"] + " " + request.route_match["*"]
route = ("GET", "/test/{match}_*", handler)
self.server.router.register(*route)
resp = self.request("/test/some_route")
self.assertEqual(b"some route", resp.read())
def test_non_ascii_in_headers(self):
@wptserve.handlers.handler
def handler(request, response):
return request.headers[b"foo"]
route = ("GET", "/test/test_unicode_in_headers", handler)
self.server.router.register(*route)
# Try some non-ASCII characters and the server shouldn't crash.
encoded_text = u"你好".encode("utf-8")
resp = self.request(route[1], headers={"foo": encoded_text})
self.assertEqual(encoded_text, resp.read())
        # Try a different encoding from utf-8 to make sure the binary value is
        # returned verbatim.
encoded_text = u"どうも".encode("shift-jis")
resp = self.request(route[1], headers={"foo": encoded_text})
self.assertEqual(encoded_text, resp.read())
def test_non_ascii_in_GET_params(self):
@wptserve.handlers.handler
def handler(request, response):
return request.GET[b"foo"]
route = ("GET", "/test/test_unicode_in_get", handler)
self.server.router.register(*route)
# We intentionally choose an encoding that's not the default UTF-8.
encoded_text = u"どうも".encode("shift-jis")
if PY3:
from urllib.parse import quote_from_bytes
quoted = quote_from_bytes(encoded_text)
else:
from urllib import quote
quoted = quote(encoded_text)
resp = self.request(route[1], query="foo="+quoted)
self.assertEqual(encoded_text, resp.read())
def test_non_ascii_in_POST_params(self):
@wptserve.handlers.handler
def handler(request, response):
return request.POST[b"foo"]
route = ("POST", "/test/test_unicode_in_POST", handler)
self.server.router.register(*route)
# We intentionally choose an encoding that's not the default UTF-8.
encoded_text = u"どうも".encode("shift-jis")
if PY3:
from urllib.parse import quote_from_bytes
# After urlencoding, the string should only contain ASCII.
quoted = quote_from_bytes(encoded_text).encode("ascii")
else:
from urllib import quote
quoted = quote(encoded_text)
resp = self.request(route[1], method="POST", body=b"foo="+quoted)
self.assertEqual(encoded_text, resp.read())
class TestAuth(TestUsingServer):
def test_auth(self):
@wptserve.handlers.handler
def handler(request, response):
return b" ".join((request.auth.username, request.auth.password))
route = ("GET", "/test/test_auth", handler)
self.server.router.register(*route)
resp = self.request(route[1], auth=(b"test", b"PASS"))
self.assertEqual(200, resp.getcode())
self.assertEqual([b"test", b"PASS"], resp.read().split(b" "))
encoded_text = u"どうも".encode("shift-jis")
resp = self.request(route[1], auth=(encoded_text, encoded_text))
self.assertEqual(200, resp.getcode())
self.assertEqual([encoded_text, encoded_text], resp.read().split(b" "))
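The `PY3`/`PY2` branching for percent-encoding raw bytes appears in both the GET and POST parameter tests. A six-free helper covering the same split — an illustrative sketch, not part of the test harness — is:

```python
try:  # Python 3
    from urllib.parse import quote_from_bytes as quote_bytes_impl
except ImportError:  # Python 2
    from urllib import quote as quote_bytes_impl


def quote_bytes(data):
    # Percent-encode arbitrary bytes (for example shift-jis encoded text)
    # for use in a query string, on either Python version.
    return quote_bytes_impl(data)
```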

Просмотреть файл

@ -0,0 +1,396 @@
import os
import unittest
import json
from io import BytesIO
import pytest
from six import create_bound_method, PY3
from six.moves.http_client import BadStatusLine
wptserve = pytest.importorskip("wptserve")
from .base import TestUsingServer, TestUsingH2Server, doc_root
from h2.exceptions import ProtocolError
def send_body_as_header(self):
if self._response.add_required_headers:
self.write_default_headers()
self.write("X-Body: ")
self._headers_complete = True
class TestResponse(TestUsingServer):
def test_head_without_body(self):
@wptserve.handlers.handler
def handler(request, response):
response.writer.end_headers = create_bound_method(send_body_as_header,
response.writer)
return [("X-Test", "TEST")], "body\r\n"
route = ("GET", "/test/test_head_without_body", handler)
self.server.router.register(*route)
resp = self.request(route[1], method="HEAD")
self.assertEqual("6", resp.info()['Content-Length'])
self.assertEqual("TEST", resp.info()['x-Test'])
self.assertEqual("", resp.info()['x-body'])
def test_head_with_body(self):
@wptserve.handlers.handler
def handler(request, response):
response.send_body_for_head_request = True
response.writer.end_headers = create_bound_method(send_body_as_header,
response.writer)
return [("X-Test", "TEST")], "body\r\n"
route = ("GET", "/test/test_head_with_body", handler)
self.server.router.register(*route)
resp = self.request(route[1], method="HEAD")
self.assertEqual("6", resp.info()['Content-Length'])
self.assertEqual("TEST", resp.info()['x-Test'])
self.assertEqual("body", resp.info()['X-Body'])
def test_write_content_no_status_no_header(self):
resp_content = b"TEST"
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_content(resp_content)
route = ("GET", "/test/test_write_content_no_status_no_header", handler)
self.server.router.register(*route)
resp = self.request(route[1])
assert resp.getcode() == 200
assert resp.read() == resp_content
assert resp.info()["Content-Length"] == str(len(resp_content))
assert "Date" in resp.info()
assert "Server" in resp.info()
def test_write_content_no_headers(self):
resp_content = b"TEST"
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_status(201)
response.writer.write_content(resp_content)
route = ("GET", "/test/test_write_content_no_headers", handler)
self.server.router.register(*route)
resp = self.request(route[1])
assert resp.getcode() == 201
assert resp.read() == resp_content
assert resp.info()["Content-Length"] == str(len(resp_content))
assert "Date" in resp.info()
assert "Server" in resp.info()
def test_write_content_no_status(self):
resp_content = b"TEST"
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_header("test-header", "test-value")
response.writer.write_content(resp_content)
route = ("GET", "/test/test_write_content_no_status", handler)
self.server.router.register(*route)
resp = self.request(route[1])
assert resp.getcode() == 200
assert resp.read() == resp_content
assert sorted([x.lower() for x in resp.info().keys()]) == sorted(['test-header', 'date', 'server', 'content-length'])
def test_write_content_no_status_no_required_headers(self):
resp_content = b"TEST"
@wptserve.handlers.handler
def handler(request, response):
response.add_required_headers = False
response.writer.write_header("test-header", "test-value")
response.writer.write_content(resp_content)
route = ("GET", "/test/test_write_content_no_status_no_required_headers", handler)
self.server.router.register(*route)
resp = self.request(route[1])
assert resp.getcode() == 200
assert resp.read() == resp_content
assert resp.info().items() == [('test-header', 'test-value')]
def test_write_content_no_status_no_headers_no_required_headers(self):
resp_content = b"TEST"
@wptserve.handlers.handler
def handler(request, response):
response.add_required_headers = False
response.writer.write_content(resp_content)
route = ("GET", "/test/test_write_content_no_status_no_headers_no_required_headers", handler)
self.server.router.register(*route)
resp = self.request(route[1])
assert resp.getcode() == 200
assert resp.read() == resp_content
assert resp.info().items() == []
def test_write_raw_content(self):
resp_content = b"HTTP/1.1 202 Giraffe\n" \
b"X-TEST: PASS\n" \
b"Content-Length: 7\n\n" \
b"Content"
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_raw_content(resp_content)
route = ("GET", "/test/test_write_raw_content", handler)
self.server.router.register(*route)
resp = self.request(route[1])
assert resp.getcode() == 202
assert resp.info()["X-TEST"] == "PASS"
assert resp.read() == b"Content"
def test_write_raw_content_file(self):
@wptserve.handlers.handler
def handler(request, response):
with open(os.path.join(doc_root, "test.asis"), 'rb') as infile:
response.writer.write_raw_content(infile)
route = ("GET", "/test/test_write_raw_content", handler)
self.server.router.register(*route)
resp = self.request(route[1])
assert resp.getcode() == 202
assert resp.info()["X-TEST"] == "PASS"
assert resp.read() == b"Content"
def test_write_raw_none(self):
@wptserve.handlers.handler
def handler(request, response):
with pytest.raises(ValueError):
response.writer.write_raw_content(None)
route = ("GET", "/test/test_write_raw_content", handler)
self.server.router.register(*route)
self.request(route[1])
def test_write_raw_contents_invalid_http(self):
resp_content = b"INVALID HTTP"
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_raw_content(resp_content)
route = ("GET", "/test/test_write_raw_content", handler)
self.server.router.register(*route)
try:
resp = self.request(route[1])
assert resp.read() == resp_content
except BadStatusLine as e:
            # In Python 3, an invalid HTTP response should raise BadStatusLine.
assert PY3
assert str(e) == resp_content.decode('utf-8')
class TestH2Response(TestUsingH2Server):
def test_write_without_ending_stream(self):
data = b"TEST"
@wptserve.handlers.handler
def handler(request, response):
headers = [
('server', 'test-h2'),
('test', 'PASS'),
]
response.writer.write_headers(headers, 202)
response.writer.write_data_frame(data, False)
# Should detect stream isn't ended and call `writer.end_stream()`
route = ("GET", "/h2test/test", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 202
assert [x for x in resp.headers.items()] == [(b'server', b'test-h2'), (b'test', b'PASS')]
assert resp.read() == data
def test_push(self):
data = b"TEST"
push_data = b"PUSH TEST"
@wptserve.handlers.handler
def handler(request, response):
headers = [
('server', 'test-h2'),
('test', 'PASS'),
]
response.writer.write_headers(headers, 202)
promise_headers = [
(':method', 'GET'),
(':path', '/push-test'),
(':scheme', 'https'),
(':authority', '%s:%i' % (self.server.host, self.server.port))
]
push_headers = [
('server', 'test-h2'),
('content-length', str(len(push_data))),
('content-type', 'text'),
]
response.writer.write_push(
promise_headers,
push_stream_id=10,
status=203,
response_headers=push_headers,
response_data=push_data
)
response.writer.write_data_frame(data, True)
route = ("GET", "/h2test/test_push", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 202
assert [x for x in resp.headers.items()] == [(b'server', b'test-h2'), (b'test', b'PASS')]
assert resp.read() == data
push_promise = next(self.conn.get_pushes())
push = push_promise.get_response()
assert push_promise.path == b'/push-test'
assert push.status == 203
assert push.read() == push_data
def test_set_error(self):
@wptserve.handlers.handler
def handler(request, response):
response.set_error(503, message="Test error")
route = ("GET", "/h2test/test_set_error", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 503
assert json.loads(resp.read()) == json.loads("{\"error\": {\"message\": \"Test error\", \"code\": 503}}")
def test_file_like_response(self):
@wptserve.handlers.handler
def handler(request, response):
content = BytesIO(b"Hello, world!")
response.content = content
route = ("GET", "/h2test/test_file_like_response", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 200
assert resp.read() == b"Hello, world!"
def test_list_response(self):
@wptserve.handlers.handler
def handler(request, response):
response.content = ['hello', 'world']
route = ("GET", "/h2test/test_file_like_response", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 200
assert resp.read() == b"helloworld"
def test_content_longer_than_frame_size(self):
@wptserve.handlers.handler
def handler(request, response):
size = response.writer.get_max_payload_size()
content = "a" * (size + 5)
return [('payload_size', size)], content
route = ("GET", "/h2test/test_content_longer_than_frame_size", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 200
payload_size = int(resp.headers['payload_size'][0])
assert payload_size
assert resp.read() == b"a" * (payload_size + 5)
def test_encode(self):
@wptserve.handlers.handler
def handler(request, response):
response.encoding = "utf8"
t = response.writer.encode(u"hello")
assert t == "hello"
with pytest.raises(ValueError):
response.writer.encode(None)
route = ("GET", "/h2test/test_content_longer_than_frame_size", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
self.conn.get_response()
def test_raw_header_frame(self):
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_raw_header_frame([
(':status', '204'),
('server', 'TEST-H2')
], end_headers=True)
route = ("GET", "/h2test/test_file_like_response", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 204
assert resp.headers['server'][0] == b'TEST-H2'
assert resp.read() == b''
def test_raw_header_frame_invalid(self):
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_raw_header_frame([
('server', 'TEST-H2'),
(':status', '204')
], end_headers=True)
route = ("GET", "/h2test/test_file_like_response", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
with pytest.raises(ProtocolError):
# The server can send an invalid HEADER frame, which will cause a protocol error in client
self.conn.get_response()
def test_raw_data_frame(self):
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_raw_data_frame(data=b'Hello world', end_stream=True)
route = ("GET", "/h2test/test_file_like_response", handler)
self.server.router.register(*route)
sid = self.conn.request(route[0], route[1])
assert self.conn.streams[sid]._read() == b'Hello world'
def test_raw_header_continuation_frame(self):
@wptserve.handlers.handler
def handler(request, response):
response.writer.write_raw_header_frame([
(':status', '204')
])
response.writer.write_raw_continuation_frame([
('server', 'TEST-H2')
], end_headers=True)
route = ("GET", "/h2test/test_file_like_response", handler)
self.server.router.register(*route)
self.conn.request(route[0], route[1])
resp = self.conn.get_response()
assert resp.status == 204
assert resp.headers['server'][0] == b'TEST-H2'
assert resp.read() == b''
if __name__ == '__main__':
unittest.main()
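`test_write_raw_content` hands the writer a fully formed response byte string. Assembling such a payload programmatically — a sketch of the byte layout the test uses; note that real HTTP prefers `\r\n` line endings, while the test uses bare `\n` — can look like:

```python
def build_raw_response(status, reason, headers, body):
    # Build a raw response matching the shape used in test_write_raw_content:
    # status line, header lines, a blank line, then the body.
    lines = [b"HTTP/1.1 %d %s" % (status, reason)]
    for name, value in headers:
        lines.append(b"%s: %s" % (name, value))
    return b"\n".join(lines) + b"\n\n" + body
```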

Просмотреть файл

@ -0,0 +1,113 @@
import unittest
import pytest
from six.moves.urllib.error import HTTPError
wptserve = pytest.importorskip("wptserve")
from .base import TestUsingServer, TestUsingH2Server
class TestFileHandler(TestUsingServer):
def test_not_handled(self):
with self.assertRaises(HTTPError) as cm:
self.request("/not_existing")
self.assertEqual(cm.exception.code, 404)
class TestRewriter(TestUsingServer):
def test_rewrite(self):
@wptserve.handlers.handler
def handler(request, response):
return request.request_path
route = ("GET", "/test/rewritten", handler)
self.server.rewriter.register("GET", "/test/original", route[1])
self.server.router.register(*route)
resp = self.request("/test/original")
self.assertEqual(200, resp.getcode())
self.assertEqual(b"/test/rewritten", resp.read())
class TestRequestHandler(TestUsingServer):
def test_exception(self):
@wptserve.handlers.handler
def handler(request, response):
raise Exception
route = ("GET", "/test/raises", handler)
self.server.router.register(*route)
with self.assertRaises(HTTPError) as cm:
self.request("/test/raises")
self.assertEqual(cm.exception.code, 500)
def test_many_headers(self):
headers = {"X-Val%d" % i: str(i) for i in range(256)}
@wptserve.handlers.handler
def handler(request, response):
# Additional headers are added by urllib.request.
assert len(request.headers) > len(headers)
for k, v in headers.items():
assert request.headers.get(k) == \
wptserve.utils.isomorphic_encode(v)
return "OK"
route = ("GET", "/test/headers", handler)
self.server.router.register(*route)
resp = self.request("/test/headers", headers=headers)
self.assertEqual(200, resp.getcode())
class TestFileHandlerH2(TestUsingH2Server):
def test_not_handled(self):
self.conn.request("GET", "/not_existing")
resp = self.conn.get_response()
assert resp.status == 404
class TestRewriterH2(TestUsingH2Server):
def test_rewrite(self):
@wptserve.handlers.handler
def handler(request, response):
return request.request_path
route = ("GET", "/test/rewritten", handler)
self.server.rewriter.register("GET", "/test/original", route[1])
self.server.router.register(*route)
self.conn.request("GET", "/test/original")
resp = self.conn.get_response()
assert resp.status == 200
assert resp.read() == b"/test/rewritten"
class TestRequestHandlerH2(TestUsingH2Server):
def test_exception(self):
@wptserve.handlers.handler
def handler(request, response):
raise Exception
route = ("GET", "/test/raises", handler)
self.server.router.register(*route)
self.conn.request("GET", "/test/raises")
resp = self.conn.get_response()
assert resp.status == 500
def test_frame_handler_exception(self):
class handler_cls:
def frame_handler(self, request):
raise Exception
route = ("GET", "/test/raises", handler_cls())
self.server.router.register(*route)
self.conn.request("GET", "/test/raises")
resp = self.conn.get_response()
assert resp.status == 500
if __name__ == "__main__":
unittest.main()

Просмотреть файл

@ -0,0 +1,44 @@
import unittest
import uuid
import pytest
wptserve = pytest.importorskip("wptserve")
from wptserve.router import any_method
from wptserve.stash import StashServer
from .base import TestUsingServer
class TestStash(TestUsingServer):
    def run(self, result=None):
        with StashServer(None, authkey=str(uuid.uuid4())):
            super(TestStash, self).run(result)
def test_put_take(self):
@wptserve.handlers.handler
def handler(request, response):
if request.method == "POST":
request.server.stash.put(request.POST.first(b"id"), request.POST.first(b"data"))
data = "OK"
elif request.method == "GET":
data = request.server.stash.take(request.GET.first(b"id"))
if data is None:
return "NOT FOUND"
return data
id = str(uuid.uuid4())
route = (any_method, "/test/put_take", handler)
self.server.router.register(*route)
resp = self.request(route[1], method="POST", body={"id": id, "data": "Sample data"})
self.assertEqual(resp.read(), b"OK")
resp = self.request(route[1], query="id=" + id)
self.assertEqual(resp.read(), b"Sample data")
resp = self.request(route[1], query="id=" + id)
self.assertEqual(resp.read(), b"NOT FOUND")
if __name__ == '__main__':
unittest.main()
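The put/take round trip above is the whole contract of the stash: a value can be taken at most once, after which the key is gone. A minimal stdlib-only model of that behaviour (a sketch only, not the real `wptserve.stash.Stash`, which proxies a shared dict through a multiprocessing manager):

```python
class TakeOnceStash(object):
    """Toy model of the stash semantics: put stores a value under a key;
    take returns it at most once and removes it, so a second take yields None."""

    def __init__(self):
        self._data = {}

    def put(self, key, value):
        # The real stash refuses to overwrite an existing key.
        if key in self._data:
            raise ValueError("key %r is already in the stash" % (key,))
        self._data[key] = value

    def take(self, key):
        # Destructive read: missing (or already-taken) keys come back as None.
        return self._data.pop(key, None)
```

The second `take` returning `None` is what the handler above maps to the "NOT FOUND" response.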


@@ -0,0 +1,385 @@
import json
import logging
import pickle
from distutils.spawn import find_executable
from logging import handlers
import pytest
config = pytest.importorskip("wptserve.config")
def test_renamed_are_renamed():
assert len(set(config._renamed_props.keys()) & set(config.ConfigBuilder._default.keys())) == 0
def test_renamed_exist():
assert set(config._renamed_props.values()).issubset(set(config.ConfigBuilder._default.keys()))
@pytest.mark.parametrize("base, override, expected", [
({"a": 1}, {"a": 2}, {"a": 2}),
({"a": 1}, {"b": 2}, {"a": 1}),
({"a": {"b": 1}}, {"a": {}}, {"a": {"b": 1}}),
({"a": {"b": 1}}, {"a": {"b": 2}}, {"a": {"b": 2}}),
({"a": {"b": 1}}, {"a": {"b": 2, "c": 3}}, {"a": {"b": 2}}),
pytest.param({"a": {"b": 1}}, {"a": 2}, {"a": 1}, marks=pytest.mark.xfail),
pytest.param({"a": 1}, {"a": {"b": 2}}, {"a": 1}, marks=pytest.mark.xfail),
])
def test_merge_dict(base, override, expected):
assert expected == config._merge_dict(base, override)
def test_logger_created():
with config.ConfigBuilder() as c:
assert c.logger is not None
def test_logger_preserved():
logger = logging.getLogger("test_logger_preserved")
logger.setLevel(logging.DEBUG)
with config.ConfigBuilder(logger=logger) as c:
assert c.logger is logger
def test_as_dict():
with config.ConfigBuilder() as c:
assert c.as_dict() is not None
def test_as_dict_is_json():
with config.ConfigBuilder() as c:
assert json.dumps(c.as_dict()) is not None
def test_init_basic_prop():
with config.ConfigBuilder(browser_host="foo.bar") as c:
assert c.browser_host == "foo.bar"
def test_init_prefixed_prop():
with config.ConfigBuilder(doc_root="/") as c:
assert c.doc_root == "/"
def test_init_renamed_host():
logger = logging.getLogger("test_init_renamed_host")
logger.setLevel(logging.DEBUG)
handler = handlers.BufferingHandler(100)
logger.addHandler(handler)
with config.ConfigBuilder(logger=logger, host="foo.bar") as c:
assert c.logger is logger
assert len(handler.buffer) == 1
assert "browser_host" in handler.buffer[0].getMessage() # check we give the new name in the message
assert not hasattr(c, "host")
assert c.browser_host == "foo.bar"
def test_init_bogus():
with pytest.raises(TypeError) as e:
config.ConfigBuilder(foo=1, bar=2)
message = e.value.args[0]
assert "foo" in message
assert "bar" in message
def test_getitem():
with config.ConfigBuilder(browser_host="foo.bar") as c:
assert c["browser_host"] == "foo.bar"
def test_no_setitem():
with config.ConfigBuilder() as c:
with pytest.raises(TypeError):
c["browser_host"] = "foo.bar"
def test_iter():
with config.ConfigBuilder() as c:
s = set(iter(c))
assert "browser_host" in s
assert "host" not in s
assert "__getitem__" not in s
assert "_browser_host" not in s
def test_assignment():
cb = config.ConfigBuilder()
cb.browser_host = "foo.bar"
with cb as c:
assert c.browser_host == "foo.bar"
def test_update_basic():
cb = config.ConfigBuilder()
cb.update({"browser_host": "foo.bar"})
with cb as c:
assert c.browser_host == "foo.bar"
def test_update_prefixed():
cb = config.ConfigBuilder()
cb.update({"doc_root": "/"})
with cb as c:
assert c.doc_root == "/"
def test_update_renamed_host():
logger = logging.getLogger("test_update_renamed_host")
logger.setLevel(logging.DEBUG)
handler = handlers.BufferingHandler(100)
logger.addHandler(handler)
cb = config.ConfigBuilder(logger=logger)
assert cb.logger is logger
assert len(handler.buffer) == 0
cb.update({"host": "foo.bar"})
with cb as c:
assert len(handler.buffer) == 1
assert "browser_host" in handler.buffer[0].getMessage() # check we give the new name in the message
assert not hasattr(c, "host")
assert c.browser_host == "foo.bar"
def test_update_bogus():
cb = config.ConfigBuilder()
with pytest.raises(KeyError):
cb.update({"foobar": 1})
def test_ports_auto():
with config.ConfigBuilder(ports={"http": ["auto"]},
ssl={"type": "none"}) as c:
ports = c.ports
assert set(ports.keys()) == {"http"}
assert len(ports["http"]) == 1
assert isinstance(ports["http"][0], int)
def test_ports_auto_mutate():
cb = config.ConfigBuilder(ports={"http": [1001]},
ssl={"type": "none"})
cb.ports = {"http": ["auto"]}
with cb as c:
new_ports = c.ports
assert set(new_ports.keys()) == {"http"}
assert len(new_ports["http"]) == 1
assert isinstance(new_ports["http"][0], int)
def test_ports_explicit():
with config.ConfigBuilder(ports={"http": [1001]},
ssl={"type": "none"}) as c:
ports = c.ports
assert set(ports.keys()) == {"http"}
assert ports["http"] == [1001]
def test_ports_no_ssl():
with config.ConfigBuilder(ports={"http": [1001], "https": [1002], "ws": [1003], "wss": [1004]},
ssl={"type": "none"}) as c:
ports = c.ports
assert set(ports.keys()) == {"http", "ws"}
assert ports["http"] == [1001]
assert ports["ws"] == [1003]
@pytest.mark.skipif(find_executable("openssl") is None,
reason="requires OpenSSL")
def test_ports_openssl():
with config.ConfigBuilder(ports={"http": [1001], "https": [1002], "ws": [1003], "wss": [1004]},
ssl={"type": "openssl"}) as c:
ports = c.ports
assert set(ports.keys()) == {"http", "https", "ws", "wss"}
assert ports["http"] == [1001]
assert ports["https"] == [1002]
assert ports["ws"] == [1003]
assert ports["wss"] == [1004]
def test_init_doc_root():
with config.ConfigBuilder(doc_root="/") as c:
assert c.doc_root == "/"
def test_set_doc_root():
cb = config.ConfigBuilder()
cb.doc_root = "/"
with cb as c:
assert c.doc_root == "/"
def test_server_host_from_browser_host():
with config.ConfigBuilder(browser_host="foo.bar") as c:
assert c.server_host == "foo.bar"
def test_init_server_host():
with config.ConfigBuilder(server_host="foo.bar") as c:
assert c.browser_host == "localhost" # check this hasn't changed
assert c.server_host == "foo.bar"
def test_set_server_host():
cb = config.ConfigBuilder()
cb.server_host = "/"
with cb as c:
assert c.browser_host == "localhost" # check this hasn't changed
assert c.server_host == "/"
def test_domains():
with config.ConfigBuilder(browser_host="foo.bar",
alternate_hosts={"alt": "foo2.bar"},
subdomains={"a", "b"},
not_subdomains={"x", "y"}) as c:
assert c.domains == {
"": {
"": "foo.bar",
"a": "a.foo.bar",
"b": "b.foo.bar",
},
"alt": {
"": "foo2.bar",
"a": "a.foo2.bar",
"b": "b.foo2.bar",
},
}
def test_not_domains():
with config.ConfigBuilder(browser_host="foo.bar",
alternate_hosts={"alt": "foo2.bar"},
subdomains={"a", "b"},
not_subdomains={"x", "y"}) as c:
not_domains = c.not_domains
assert not_domains == {
"": {
"x": "x.foo.bar",
"y": "y.foo.bar",
},
"alt": {
"x": "x.foo2.bar",
"y": "y.foo2.bar",
},
}
def test_domains_not_domains_intersection():
with config.ConfigBuilder(browser_host="foo.bar",
alternate_hosts={"alt": "foo2.bar"},
subdomains={"a", "b"},
not_subdomains={"x", "y"}) as c:
domains = c.domains
not_domains = c.not_domains
assert len(set(domains.keys()) ^ set(not_domains.keys())) == 0
for host in domains.keys():
host_domains = domains[host]
host_not_domains = not_domains[host]
assert len(set(host_domains.keys()) & set(host_not_domains.keys())) == 0
assert len(set(host_domains.values()) & set(host_not_domains.values())) == 0
def test_all_domains():
with config.ConfigBuilder(browser_host="foo.bar",
alternate_hosts={"alt": "foo2.bar"},
subdomains={"a", "b"},
not_subdomains={"x", "y"}) as c:
all_domains = c.all_domains
assert all_domains == {
"": {
"": "foo.bar",
"a": "a.foo.bar",
"b": "b.foo.bar",
"x": "x.foo.bar",
"y": "y.foo.bar",
},
"alt": {
"": "foo2.bar",
"a": "a.foo2.bar",
"b": "b.foo2.bar",
"x": "x.foo2.bar",
"y": "y.foo2.bar",
},
}
def test_domains_set():
with config.ConfigBuilder(browser_host="foo.bar",
alternate_hosts={"alt": "foo2.bar"},
subdomains={"a", "b"},
not_subdomains={"x", "y"}) as c:
domains_set = c.domains_set
assert domains_set == {
"foo.bar",
"a.foo.bar",
"b.foo.bar",
"foo2.bar",
"a.foo2.bar",
"b.foo2.bar",
}
def test_not_domains_set():
with config.ConfigBuilder(browser_host="foo.bar",
alternate_hosts={"alt": "foo2.bar"},
subdomains={"a", "b"},
not_subdomains={"x", "y"}) as c:
not_domains_set = c.not_domains_set
assert not_domains_set == {
"x.foo.bar",
"y.foo.bar",
"x.foo2.bar",
"y.foo2.bar",
}
def test_all_domains_set():
with config.ConfigBuilder(browser_host="foo.bar",
alternate_hosts={"alt": "foo2.bar"},
subdomains={"a", "b"},
not_subdomains={"x", "y"}) as c:
all_domains_set = c.all_domains_set
assert all_domains_set == {
"foo.bar",
"a.foo.bar",
"b.foo.bar",
"x.foo.bar",
"y.foo.bar",
"foo2.bar",
"a.foo2.bar",
"b.foo2.bar",
"x.foo2.bar",
"y.foo2.bar",
}
def test_ssl_env_none():
with config.ConfigBuilder(ssl={"type": "none"}) as c:
assert c.ssl_config is None
def test_ssl_env_openssl():
# TODO: this currently actually tries to start OpenSSL, which isn't ideal
# with config.ConfigBuilder(ssl={"type": "openssl", "openssl": {"openssl_binary": "foobar"}}) as c:
# assert c.ssl_env is not None
# assert c.ssl_env.ssl_enabled is True
# assert c.ssl_env.binary == "foobar"
pass
def test_ssl_env_bogus():
with pytest.raises(ValueError):
with config.ConfigBuilder(ssl={"type": "foobar"}):
pass
def test_pickle():
# Ensure that the config object can be pickled
with config.ConfigBuilder() as c:
pickle.dumps(c)
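The `test_merge_dict` cases above pin down the merge rules: only keys already present in the base dictionary survive, and values that are themselves dicts are merged recursively. A self-contained sketch mirroring `wptserve.config._merge_dict` (reimplemented here rather than imported):

```python
def merge_dict(base_dict, override_dict):
    """Merge override values into a copy of base_dict.

    Keys absent from base_dict are ignored; nested dicts merge recursively."""
    rv = base_dict.copy()
    for key, value in base_dict.items():
        if key in override_dict:
            if isinstance(value, dict):
                rv[key] = merge_dict(value, override_dict[key])
            else:
                rv[key] = override_dict[key]
    return rv
```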


@@ -0,0 +1,40 @@
from __future__ import unicode_literals
import pytest
from wptserve.pipes import ReplacementTokenizer
@pytest.mark.parametrize(
"content,expected",
[
[b"aaa", [('ident', 'aaa')]],
[b"bbb()", [('ident', 'bbb'), ('arguments', [])]],
[b"bcd(uvw, xyz)", [('ident', 'bcd'), ('arguments', ['uvw', 'xyz'])]],
[b"$ccc:ddd", [('var', '$ccc'), ('ident', 'ddd')]],
[b"$eee", [('ident', '$eee')]],
[b"fff[0]", [('ident', 'fff'), ('index', 0)]],
[b"ggg[hhh]", [('ident', 'ggg'), ('index', 'hhh')]],
[b"[iii]", [('index', 'iii')]],
[b"jjj['kkk']", [('ident', 'jjj'), ('index', "'kkk'")]],
[b"lll[]", [('ident', 'lll'), ('index', "")]],
[b"111", [('ident', '111')]],
[b"$111", [('ident', '$111')]],
]
)
def test_tokenizer(content, expected):
tokenizer = ReplacementTokenizer()
tokens = tokenizer.tokenize(content)
assert expected == tokens
@pytest.mark.parametrize(
"content,expected",
[
[b"/", []],
[b"$aaa: BBB", [('var', '$aaa')]],
]
)
def test_tokenizer_errors(content, expected):
tokenizer = ReplacementTokenizer()
tokens = tokenizer.tokenize(content)
assert expected == tokens


@@ -0,0 +1,105 @@
import mock
from six import binary_type
from wptserve.request import Request, RequestHeaders, MultiDict
class MockHTTPMessage(dict):
"""A minimum (and not completely correctly) mock of HTTPMessage for testing.
Constructing HTTPMessage is annoying and different in Python 2 and 3. This
only implements the parts used by RequestHeaders.
Requirements for construction:
* Keys are header names and MUST be lower-case.
* Values are lists of header values (even if there's only one).
* Keys and values should be native strings to match stdlib's behaviours.
"""
def __getitem__(self, key):
assert isinstance(key, str)
values = dict.__getitem__(self, key.lower())
assert isinstance(values, list)
return values[0]
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def getallmatchingheaders(self, key):
values = dict.__getitem__(self, key.lower())
return ["{}: {}\n".format(key, v) for v in values]
def test_request_headers_get():
raw_headers = MockHTTPMessage({
'x-foo': ['foo'],
'x-bar': ['bar1', 'bar2'],
})
headers = RequestHeaders(raw_headers)
assert headers['x-foo'] == b'foo'
assert headers['X-Bar'] == b'bar1, bar2'
assert headers.get('x-bar') == b'bar1, bar2'
def test_request_headers_encoding():
raw_headers = MockHTTPMessage({
'x-foo': ['foo'],
'x-bar': ['bar1', 'bar2'],
})
headers = RequestHeaders(raw_headers)
assert isinstance(headers['x-foo'], binary_type)
assert isinstance(headers['x-bar'], binary_type)
assert isinstance(headers.get_list('x-bar')[0], binary_type)
def test_request_url_from_server_address():
request_handler = mock.Mock()
request_handler.server.scheme = 'http'
request_handler.server.server_address = ('localhost', '8000')
request_handler.path = '/demo'
request_handler.headers = MockHTTPMessage()
request = Request(request_handler)
assert request.url == 'http://localhost:8000/demo'
assert isinstance(request.url, str)
def test_request_url_from_host_header():
request_handler = mock.Mock()
request_handler.server.scheme = 'http'
request_handler.server.server_address = ('localhost', '8000')
request_handler.path = '/demo'
request_handler.headers = MockHTTPMessage({'host': ['web-platform.test:8001']})
request = Request(request_handler)
assert request.url == 'http://web-platform.test:8001/demo'
assert isinstance(request.url, str)
def test_multidict():
m = MultiDict()
m["foo"] = "bar"
m["bar"] = "baz"
m.add("foo", "baz")
m.add("baz", "qux")
assert m["foo"] == "bar"
assert m.get("foo") == "bar"
assert m["bar"] == "baz"
assert m.get("bar") == "baz"
assert m["baz"] == "qux"
assert m.get("baz") == "qux"
assert m.first("foo") == "bar"
assert m.last("foo") == "baz"
assert m.get_list("foo") == ["bar", "baz"]
assert m.get_list("non_existent") == []
assert m.get("non_existent") is None
try:
m["non_existent"]
assert False, "An exception should be raised"
except KeyError:
pass
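`test_multidict` above documents the semantics worth calling out: each key maps to a list of values, plain lookup and `get` return the first value rather than the list, and `get_list` on a missing key returns an empty list. A stdlib-only model of those semantics (a sketch, not the real `wptserve.request.MultiDict`):

```python
class SimpleMultiDict(dict):
    """Each key maps to a list of values; plain lookup returns the first."""

    def __setitem__(self, key, value):
        # Assignment replaces any existing values with a one-element list.
        dict.__setitem__(self, key, [value])

    def add(self, key, value):
        # add() appends, preserving earlier values for the same key.
        if key in self:
            dict.__getitem__(self, key).append(value)
        else:
            self[key] = value

    def __getitem__(self, key):
        return self.first(key)

    def get(self, key, default=None):
        try:
            return self.first(key)
        except KeyError:
            return default

    def first(self, key):
        return dict.__getitem__(self, key)[0]

    def last(self, key):
        return dict.__getitem__(self, key)[-1]

    def get_list(self, key, default=None):
        if key in self:
            return dict.__getitem__(self, key)
        return [] if default is None else default
```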


@@ -0,0 +1,32 @@
import mock
from six import BytesIO
from wptserve.response import Response
def test_response_status():
cases = [200, (200, b'OK'), (200, u'OK'), ('200', 'OK')]
for case in cases:
handler = mock.Mock()
handler.wfile = BytesIO()
request = mock.Mock()
request.protocol_version = 'HTTP/1.1'
response = Response(handler, request)
response.status = case
expected = case if isinstance(case, tuple) else (case, None)
if expected[0] == '200':
expected = (200, expected[1])
assert response.status == expected
response.writer.write_status(*response.status)
assert handler.wfile.getvalue() == b'HTTP/1.1 200 OK\r\n'
def test_response_status_not_string():
# This behaviour is not documented but kept for backward compatibility.
handler = mock.Mock()
request = mock.Mock()
response = Response(handler, request)
response.status = (200, 100)
assert response.status == (200, '100')
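The two tests above pin down the status setter's coercion rules: a bare code becomes `(code, None)`, a numeric-string code is converted to `int`, and a non-string reason is stringified. A sketch of that normalization in isolation (the real logic lives in the `Response.status` setter; `normalize_status` is a hypothetical helper name):

```python
def normalize_status(value):
    # Accept either a bare status code or a (code, reason) tuple.
    if isinstance(value, tuple):
        code, reason = value
    else:
        code, reason = value, None
    code = int(code)  # '200' -> 200
    if reason is not None and not isinstance(reason, (str, bytes)):
        reason = str(reason)  # e.g. (200, 100) -> (200, '100')
    return (code, reason)
```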


@@ -0,0 +1,147 @@
import multiprocessing
import threading
import sys
from multiprocessing.managers import BaseManager
import pytest
from six import PY3
Stash = pytest.importorskip("wptserve.stash").Stash
@pytest.fixture()
def add_cleanup():
fns = []
def add(fn):
fns.append(fn)
yield add
for fn in fns:
fn()
def run(process_queue, request_lock, response_lock):
"""Create two Stash instances in parallel threads. Use the provided locks
to ensure the first thread is actively establishing an interprocess
communication channel at the moment the second thread executes."""
def target(thread_queue):
stash = Stash("/", ("localhost", 4543), b"some key")
# The `lock` property of the Stash instance should always be set
# immediately following initialization. These values are asserted in
# the active test.
thread_queue.put(stash.lock is None)
thread_queue = multiprocessing.Queue()
first = threading.Thread(target=target, args=(thread_queue,))
second = threading.Thread(target=target, args=(thread_queue,))
request_lock.acquire()
response_lock.acquire()
first.start()
request_lock.acquire()
# At this moment, the `first` thread is waiting for a proxied object.
# Create a second thread in order to inspect the behavior of the Stash
# constructor at this moment.
second.start()
# Allow the `first` thread to proceed
response_lock.release()
# Wait for both threads to complete and report their state to the test
process_queue.put(thread_queue.get())
process_queue.put(thread_queue.get())
class SlowLock(BaseManager):
# This can only be used in test_delayed_lock since that test modifies the
# class body, but it has to be a global for multiprocessing
pass
@pytest.mark.xfail(sys.platform == "win32" or
PY3 and multiprocessing.get_start_method() == "spawn",
reason="https://github.com/web-platform-tests/wpt/issues/16938")
def test_delayed_lock(add_cleanup):
"""Ensure that delays in proxied Lock retrieval do not interfere with
initialization in parallel threads."""
request_lock = multiprocessing.Lock()
response_lock = multiprocessing.Lock()
queue = multiprocessing.Queue()
def mutex_lock_request():
"""This request handler allows the caller to delay execution of a
thread which has requested a proxied representation of the `lock`
property, simulating a "slow" interprocess communication channel."""
request_lock.release()
response_lock.acquire()
return threading.Lock()
SlowLock.register("get_dict", callable=lambda: {})
SlowLock.register("Lock", callable=mutex_lock_request)
slowlock = SlowLock(("localhost", 4543), b"some key")
slowlock.start()
add_cleanup(lambda: slowlock.shutdown())
parallel = multiprocessing.Process(target=run,
args=(queue, request_lock, response_lock))
parallel.start()
add_cleanup(lambda: parallel.terminate())
assert [queue.get(), queue.get()] == [False, False], (
"both instances had valid locks")
class SlowDict(BaseManager):
# This can only be used in test_delayed_dict since that test modifies the
# class body, but it has to be a global for multiprocessing
pass
@pytest.mark.xfail(sys.platform == "win32" or
PY3 and multiprocessing.get_start_method() == "spawn",
reason="https://github.com/web-platform-tests/wpt/issues/16938")
def test_delayed_dict(add_cleanup):
"""Ensure that delays in proxied `dict` retrieval do not interfere with
initialization in parallel threads."""
request_lock = multiprocessing.Lock()
response_lock = multiprocessing.Lock()
queue = multiprocessing.Queue()
def mutex_dict_request():
"""This request handler allows the caller to delay execution of a
thread which has requested a proxied representation of the `get_dict`
property, simulating a "slow" interprocess communication channel."""
request_lock.release()
response_lock.acquire()
return {}
SlowDict.register("get_dict", callable=mutex_dict_request)
SlowDict.register("Lock", callable=lambda: threading.Lock())
slowdict = SlowDict(("localhost", 4543), b"some key")
slowdict.start()
add_cleanup(lambda: slowdict.shutdown())
parallel = multiprocessing.Process(target=run,
args=(queue, request_lock, response_lock))
parallel.start()
add_cleanup(lambda: parallel.terminate())
assert [queue.get(), queue.get()] == [False, False], (
"both instances had valid locks")


@@ -0,0 +1,3 @@
from .server import WebTestHttpd, WebTestServer, Router # noqa: F401
from .request import Request # noqa: F401
from .response import Response # noqa: F401


@@ -0,0 +1,353 @@
import copy
import logging
import os
from collections import defaultdict
from six.moves.collections_abc import Mapping
from six import integer_types, iteritems, itervalues, string_types
from . import sslutils
from .utils import get_port
_renamed_props = {
"host": "browser_host",
"bind_hostname": "bind_address",
"external_host": "server_host",
"host_ip": "server_host",
}
def _merge_dict(base_dict, override_dict):
rv = base_dict.copy()
for key, value in iteritems(base_dict):
if key in override_dict:
if isinstance(value, dict):
rv[key] = _merge_dict(value, override_dict[key])
else:
rv[key] = override_dict[key]
return rv
class Config(Mapping):
"""wptserve config
Inherits from Mapping for backwards compatibility with the old dict-based config"""
def __init__(self, logger_name, data):
self.__dict__["_logger_name"] = logger_name
self.__dict__.update(data)
def __str__(self):
return str(self.__dict__)
def __setattr__(self, key, value):
raise ValueError("Config is immutable")
def __setitem__(self, key):
raise ValueError("Config is immutable")
def __getitem__(self, key):
try:
return getattr(self, key)
except AttributeError:
raise ValueError
def __contains__(self, key):
return key in self.__dict__
def __iter__(self):
return (x for x in self.__dict__ if not x.startswith("_"))
def __len__(self):
return len([item for item in self])
@property
def logger(self):
logger = logging.getLogger(self._logger_name)
logger.setLevel(self.log_level.upper())
return logger
def as_dict(self):
return json_types(self.__dict__)
# Environment variables are limited in size so we need to prune the most egregious contributors
# to size, the origin policy subdomains.
def as_dict_for_wd_env_variable(self):
result = self.as_dict()
for key in [
("subdomains",),
("domains", "alt"),
("domains", ""),
("all_domains", "alt"),
("all_domains", ""),
("domains_set",),
("all_domains_set",)
]:
target = result
for part in key[:-1]:
target = target[part]
value = target[key[-1]]
if isinstance(value, dict):
target[key[-1]] = {k:v for (k,v) in iteritems(value) if not k.startswith("op")}
else:
target[key[-1]] = [x for x in value if not x.startswith("op")]
return result
def json_types(obj):
if isinstance(obj, dict):
return {key: json_types(value) for key, value in iteritems(obj)}
if (isinstance(obj, string_types) or
isinstance(obj, integer_types) or
isinstance(obj, float) or
isinstance(obj, bool) or
obj is None):
return obj
if isinstance(obj, list) or hasattr(obj, "__iter__"):
return [json_types(value) for value in obj]
raise ValueError
class ConfigBuilder(object):
"""Builder object for setting the wptsync config.
Configuration can be passed in as a dictionary to the constructor, or
set via attributes after construction. Configuration options must match
the keys on the _default class property.
The generated configuration is obtained by using the builder
object as a context manager; this returns a Config object
containing immutable configuration that may be shared between
threads and processes. In general the configuration is only valid
for the context used to obtain it.
with ConfigBuilder() as config:
# Use the configuration
print(config.browser_host)
The properties on the final configuration include those explicitly
supplied and computed properties. The computed properties are
defined by the computed_properties attribute on the class. This
is a list of property names, each corresponding to a _get_<name>
method on the class. These methods are called in the order defined
in computed_properties and are passed a single argument, a
dictionary containing the current set of properties. Thus computed
properties later in the list may depend on the value of earlier
ones.
"""
_default = {
"browser_host": "localhost",
"alternate_hosts": {},
"doc_root": os.path.dirname("__file__"),
"server_host": None,
"ports": {"http": [8000]},
"check_subdomains": True,
"log_level": "debug",
"bind_address": True,
"ssl": {
"type": "none",
"encrypt_after_connect": False,
"none": {},
"openssl": {
"openssl_binary": "openssl",
"base_path": "_certs",
"password": "web-platform-tests",
"force_regenerate": False,
"duration": 30,
"base_conf_path": None
},
"pregenerated": {
"host_key_path": None,
"host_cert_path": None,
},
},
"aliases": []
}
default_config_cls = Config
# Configuration properties that are computed. Each corresponds to a method
# _get_foo, which is called with the current data dictionary. The properties
# are computed in the order specified in the list.
computed_properties = ["log_level",
"paths",
"server_host",
"ports",
"domains",
"not_domains",
"all_domains",
"domains_set",
"not_domains_set",
"all_domains_set",
"ssl_config"]
def __init__(self,
logger=None,
subdomains=set(),
not_subdomains=set(),
config_cls=None,
**kwargs):
self._data = self._default.copy()
self._ssl_env = None
self._config_cls = config_cls or self.default_config_cls
if logger is None:
self._logger_name = "web-platform-tests"
else:
level_name = logging.getLevelName(logger.level)
if level_name != "NOTSET":
self.log_level = level_name
self._logger_name = logger.name
for k, v in iteritems(self._default):
self._data[k] = kwargs.pop(k, v)
self._data["subdomains"] = subdomains
self._data["not_subdomains"] = not_subdomains
for k, new_k in iteritems(_renamed_props):
if k in kwargs:
self.logger.warning(
"%s in config is deprecated; use %s instead" % (
k,
new_k
)
)
self._data[new_k] = kwargs.pop(k)
if kwargs:
raise TypeError("__init__() got unexpected keyword arguments %r" % (tuple(kwargs),))
def __setattr__(self, key, value):
if not key[0] == "_":
self._data[key] = value
else:
self.__dict__[key] = value
@property
def logger(self):
logger = logging.getLogger(self._logger_name)
logger.setLevel(self._data["log_level"].upper())
return logger
def update(self, override):
"""Load an overrides dict to override config values"""
override = override.copy()
for k in self._default:
if k in override:
self._set_override(k, override.pop(k))
for k, new_k in iteritems(_renamed_props):
if k in override:
self.logger.warning(
"%s in config is deprecated; use %s instead" % (
k,
new_k
)
)
self._set_override(new_k, override.pop(k))
if override:
k = next(iter(override))
raise KeyError("unknown config override '%s'" % k)
def _set_override(self, k, v):
old_v = self._data[k]
if isinstance(old_v, dict):
self._data[k] = _merge_dict(old_v, v)
else:
self._data[k] = v
def __enter__(self):
if self._ssl_env is not None:
raise ValueError("Tried to re-enter configuration")
data = self._data.copy()
prefix = "_get_"
for key in self.computed_properties:
data[key] = getattr(self, prefix + key)(data)
return self._config_cls(self._logger_name, data)
def __exit__(self, *args):
self._ssl_env.__exit__(*args)
self._ssl_env = None
def _get_log_level(self, data):
return data["log_level"].upper()
def _get_paths(self, data):
return {"doc_root": data["doc_root"]}
def _get_server_host(self, data):
return data["server_host"] if data.get("server_host") is not None else data["browser_host"]
def _get_ports(self, data):
new_ports = defaultdict(list)
for scheme, ports in iteritems(data["ports"]):
if scheme in ["wss", "https"] and not sslutils.get_cls(data["ssl"]["type"]).ssl_enabled:
continue
for i, port in enumerate(ports):
real_port = get_port("") if port == "auto" else port
new_ports[scheme].append(real_port)
return new_ports
def _get_domains(self, data):
hosts = data["alternate_hosts"].copy()
assert "" not in hosts
hosts[""] = data["browser_host"]
rv = {}
for name, host in iteritems(hosts):
rv[name] = {subdomain: (subdomain.encode("idna").decode("ascii") + u"." + host)
for subdomain in data["subdomains"]}
rv[name][""] = host
return rv
def _get_not_domains(self, data):
hosts = data["alternate_hosts"].copy()
assert "" not in hosts
hosts[""] = data["browser_host"]
rv = {}
for name, host in iteritems(hosts):
rv[name] = {subdomain: (subdomain.encode("idna").decode("ascii") + u"." + host)
for subdomain in data["not_subdomains"]}
return rv
def _get_all_domains(self, data):
rv = copy.deepcopy(data["domains"])
nd = data["not_domains"]
for host in rv:
rv[host].update(nd[host])
return rv
def _get_domains_set(self, data):
return {domain
for per_host_domains in itervalues(data["domains"])
for domain in itervalues(per_host_domains)}
def _get_not_domains_set(self, data):
return {domain
for per_host_domains in itervalues(data["not_domains"])
for domain in itervalues(per_host_domains)}
def _get_all_domains_set(self, data):
return data["domains_set"] | data["not_domains_set"]
def _get_ssl_config(self, data):
ssl_type = data["ssl"]["type"]
ssl_cls = sslutils.get_cls(ssl_type)
kwargs = data["ssl"].get(ssl_type, {})
self._ssl_env = ssl_cls(self.logger, **kwargs)
self._ssl_env.__enter__()
if self._ssl_env.ssl_enabled:
key_path, cert_path = self._ssl_env.host_cert_path(data["domains_set"])
ca_cert_path = self._ssl_env.ca_cert_path(data["domains_set"])
return {"key_path": key_path,
"ca_cert_path": ca_cert_path,
"cert_path": cert_path,
"encrypt_after_connect": data["ssl"].get("encrypt_after_connect", False)}


@@ -0,0 +1,97 @@
from . import utils
content_types = utils.invert_dict({
"application/json": ["json"],
"application/wasm": ["wasm"],
"application/xhtml+xml": ["xht", "xhtm", "xhtml"],
"application/xml": ["xml"],
"application/x-xpinstall": ["xpi"],
"audio/mp4": ["m4a"],
"audio/mpeg": ["mp3"],
"audio/ogg": ["oga"],
"audio/webm": ["weba"],
"audio/x-wav": ["wav"],
"image/bmp": ["bmp"],
"image/gif": ["gif"],
"image/jpeg": ["jpg", "jpeg"],
"image/png": ["png"],
"image/svg+xml": ["svg"],
"text/cache-manifest": ["manifest"],
"text/css": ["css"],
"text/event-stream": ["event_stream"],
"text/html": ["htm", "html"],
"text/javascript": ["js", "mjs"],
"text/plain": ["txt", "md"],
"text/vtt": ["vtt"],
"video/mp4": ["mp4", "m4v"],
"video/ogg": ["ogg", "ogv"],
"video/webm": ["webm"],
})
response_codes = {
100: ('Continue', 'Request received, please continue'),
101: ('Switching Protocols',
'Switching to new protocol; obey Upgrade header'),
200: ('OK', 'Request fulfilled, document follows'),
201: ('Created', 'Document created, URL follows'),
202: ('Accepted',
'Request accepted, processing continues off-line'),
203: ('Non-Authoritative Information', 'Request fulfilled from cache'),
204: ('No Content', 'Request fulfilled, nothing follows'),
205: ('Reset Content', 'Clear input form for further input.'),
206: ('Partial Content', 'Partial content follows.'),
300: ('Multiple Choices',
'Object has several resources -- see URI list'),
301: ('Moved Permanently', 'Object moved permanently -- see URI list'),
302: ('Found', 'Object moved temporarily -- see URI list'),
303: ('See Other', 'Object moved -- see Method and URL list'),
304: ('Not Modified',
'Document has not changed since given time'),
305: ('Use Proxy',
'You must use proxy specified in Location to access this '
'resource.'),
307: ('Temporary Redirect',
'Object moved temporarily -- see URI list'),
400: ('Bad Request',
'Bad request syntax or unsupported method'),
401: ('Unauthorized',
'No permission -- see authorization schemes'),
402: ('Payment Required',
'No payment -- see charging schemes'),
403: ('Forbidden',
'Request forbidden -- authorization will not help'),
404: ('Not Found', 'Nothing matches the given URI'),
405: ('Method Not Allowed',
'Specified method is invalid for this resource.'),
406: ('Not Acceptable', 'URI not available in preferred format.'),
407: ('Proxy Authentication Required', 'You must authenticate with '
'this proxy before proceeding.'),
408: ('Request Timeout', 'Request timed out; try again later.'),
409: ('Conflict', 'Request conflict.'),
410: ('Gone',
'URI no longer exists and has been permanently removed.'),
411: ('Length Required', 'Client must specify Content-Length.'),
412: ('Precondition Failed', 'Precondition in headers is false.'),
413: ('Request Entity Too Large', 'Entity is too large.'),
414: ('Request-URI Too Long', 'URI is too long.'),
415: ('Unsupported Media Type', 'Entity body in unsupported format.'),
416: ('Requested Range Not Satisfiable',
'Cannot satisfy request range.'),
417: ('Expectation Failed',
'Expect condition could not be satisfied.'),
500: ('Internal Server Error', 'Server got itself in trouble'),
501: ('Not Implemented',
'Server does not support this operation'),
502: ('Bad Gateway', 'Invalid responses from another server/proxy.'),
503: ('Service Unavailable',
'The server cannot process the request due to a high load'),
504: ('Gateway Timeout',
'The gateway server did not receive a timely response'),
505: ('HTTP Version Not Supported', 'Cannot fulfill request.'),
}
h2_headers = ['method', 'scheme', 'host', 'path', 'authority', 'status']
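`content_types` above is written type-to-extensions and inverted by `utils.invert_dict` into the extension-to-type map that handlers consult. A sketch of such an inversion with a duplicate-extension guard (assumed behaviour; the real helper is `wptserve.utils.invert_dict`):

```python
def invert_dict(dict_of_lists):
    # Build extension -> content type from content type -> [extensions],
    # refusing ambiguous input where an extension appears twice.
    rv = {}
    for key, values in dict_of_lists.items():
        for value in values:
            if value in rv:
                raise ValueError("duplicate value %r" % (value,))
            rv[value] = key
    return rv
```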


@@ -0,0 +1,514 @@
import json
import os
import traceback
from collections import defaultdict
from six.moves.urllib.parse import quote, unquote, urljoin
from six import iteritems
from .constants import content_types
from .pipes import Pipeline, template
from .ranges import RangeParser
from .request import Authentication
from .response import MultipartContent
from .utils import HTTPException
try:
from html import escape
except ImportError:
from cgi import escape
__all__ = ["file_handler", "python_script_handler",
"FunctionHandler", "handler", "json_handler",
"as_is_handler", "ErrorHandler", "BasicAuthHandler"]
def guess_content_type(path):
ext = os.path.splitext(path)[1].lstrip(".")
if ext in content_types:
return content_types[ext]
return "application/octet-stream"
def filesystem_path(base_path, request, url_base="/"):
if base_path is None:
base_path = request.doc_root
path = unquote(request.url_parts.path)
if path.startswith(url_base):
path = path[len(url_base):]
if ".." in path:
raise HTTPException(404)
new_path = os.path.join(base_path, path)
# Otherwise setting path to / allows access outside the root directory
if not new_path.startswith(base_path):
raise HTTPException(404)
return new_path
class DirectoryHandler(object):
def __init__(self, base_path=None, url_base="/"):
self.base_path = base_path
self.url_base = url_base
def __repr__(self):
return "<%s base_path:%s url_base:%s>" % (self.__class__.__name__, self.base_path, self.url_base)
def __call__(self, request, response):
url_path = request.url_parts.path
if not url_path.endswith("/"):
response.status = 301
response.headers = [("Location", "%s/" % request.url)]
return
path = filesystem_path(self.base_path, request, self.url_base)
assert os.path.isdir(path)
response.headers = [("Content-Type", "text/html")]
response.content = """<!doctype html>
<meta name="viewport" content="width=device-width">
<title>Directory listing for %(path)s</title>
<h1>Directory listing for %(path)s</h1>
<ul>
%(items)s
</ul>
""" % {"path": escape(url_path),
"items": "\n".join(self.list_items(url_path, path))} # noqa: E122
def list_items(self, base_path, path):
assert base_path.endswith("/")
# TODO: this won't actually list all routes, only the
# ones that correspond to a real filesystem path. It's
# not possible to list every route that will match
# something, but it should be possible to at least list the
# statically defined ones
if base_path != "/":
link = urljoin(base_path, "..")
yield ("""<li class="dir"><a href="%(link)s">%(name)s</a></li>""" %
{"link": link, "name": ".."})
items = []
prev_item = None
# This ensures that .headers always sorts after the file it provides the headers for. E.g.,
# if we have x, x-y, and x.headers, the order will be x, x.headers, and then x-y.
for item in sorted(os.listdir(path), key=lambda x: (x[:-len(".headers")], x) if x.endswith(".headers") else (x, x)):
if prev_item and prev_item + ".headers" == item:
items[-1][1] = item
prev_item = None
continue
items.append([item, None])
prev_item = item
for item, dot_headers in items:
link = escape(quote(item))
dot_headers_markup = ""
if dot_headers is not None:
dot_headers_markup = (""" (<a href="%(link)s">.headers</a>)""" %
{"link": escape(quote(dot_headers))})
if os.path.isdir(os.path.join(path, item)):
link += "/"
class_ = "dir"
else:
class_ = "file"
yield ("""<li class="%(class)s"><a href="%(link)s">%(name)s</a>%(headers)s</li>""" %
{"link": link, "name": escape(item), "class": class_,
"headers": dot_headers_markup})
def parse_qs(qs):
"""Parse a query string given as a string argument (data of type
application/x-www-form-urlencoded). Data are returned as a dictionary. The
dictionary keys are the unique query variable names and the values are
lists of values for each name.
This implementation is used instead of Python's built-in `parse_qs` method
in order to support the semicolon character (which the built-in method
interprets as a parameter delimiter)."""
pairs = [item.split("=", 1) for item in qs.split('&') if item]
rv = defaultdict(list)
for pair in pairs:
if len(pair) == 1 or len(pair[1]) == 0:
continue
name = unquote(pair[0].replace('+', ' '))
value = unquote(pair[1].replace('+', ' '))
rv[name].append(value)
return dict(rv)
def wrap_pipeline(path, request, response):
"""Applies pipelines to a response.
Pipelines are specified in the filename (.sub.) or the query param (?pipe).
"""
query = parse_qs(request.url_parts.query)
pipe_string = ""
if ".sub." in path:
ml_extensions = {".html", ".htm", ".xht", ".xhtml", ".xml", ".svg"}
escape_type = "html" if os.path.splitext(path)[1] in ml_extensions else "none"
pipe_string = "sub(%s)" % escape_type
if "pipe" in query:
if pipe_string:
pipe_string += "|"
pipe_string += query["pipe"][-1]
if pipe_string:
response = Pipeline(pipe_string)(request, response)
return response
def load_headers(request, path):
"""Loads headers from files for a given path.
Attempts to load both the neighbouring __dir__{.sub}.headers and
PATH{.sub}.headers (applying template substitution if needed); results are
concatenated in that order.
"""
def _load(request, path):
headers_path = path + ".sub.headers"
if os.path.exists(headers_path):
use_sub = True
else:
headers_path = path + ".headers"
use_sub = False
try:
with open(headers_path, "rb") as headers_file:
data = headers_file.read()
except IOError:
return []
else:
if use_sub:
data = template(request, data, escape_type="none")
return [tuple(item.strip() for item in line.split(b":", 1))
for line in data.splitlines() if line]
return (_load(request, os.path.join(os.path.dirname(path), "__dir__")) +
_load(request, path))
class FileHandler(object):
def __init__(self, base_path=None, url_base="/"):
self.base_path = base_path
self.url_base = url_base
self.directory_handler = DirectoryHandler(self.base_path, self.url_base)
def __repr__(self):
return "<%s base_path:%s url_base:%s>" % (self.__class__.__name__, self.base_path, self.url_base)
def __call__(self, request, response):
path = filesystem_path(self.base_path, request, self.url_base)
if os.path.isdir(path):
return self.directory_handler(request, response)
try:
#This is probably racy with some other process trying to change the file
file_size = os.stat(path).st_size
response.headers.update(self.get_headers(request, path))
if "Range" in request.headers:
try:
byte_ranges = RangeParser()(request.headers['Range'], file_size)
except HTTPException as e:
if e.code == 416:
response.headers.set("Content-Range", "bytes */%i" % file_size)
raise
else:
byte_ranges = None
data = self.get_data(response, path, byte_ranges)
response.content = data
response = wrap_pipeline(path, request, response)
return response
except (OSError, IOError):
raise HTTPException(404)
def get_headers(self, request, path):
rv = load_headers(request, path)
if not any(key.lower() == b"content-type" for (key, _) in rv):
rv.insert(0, (b"Content-Type", guess_content_type(path).encode("ascii")))
return rv
def get_data(self, response, path, byte_ranges):
"""Return either the handle to a file, or a string containing
the content of a chunk of the file, if we have a range request."""
if byte_ranges is None:
return open(path, 'rb')
else:
with open(path, 'rb') as f:
response.status = 206
if len(byte_ranges) > 1:
parts_content_type, content = self.set_response_multipart(response,
byte_ranges,
f)
for byte_range in byte_ranges:
content.append_part(self.get_range_data(f, byte_range),
parts_content_type,
[("Content-Range", byte_range.header_value())])
return content
else:
response.headers.set("Content-Range", byte_ranges[0].header_value())
return self.get_range_data(f, byte_ranges[0])
def set_response_multipart(self, response, ranges, f):
parts_content_type = response.headers.get("Content-Type")
if parts_content_type:
parts_content_type = parts_content_type[-1]
else:
parts_content_type = None
content = MultipartContent()
response.headers.set("Content-Type", "multipart/byteranges; boundary=%s" % content.boundary)
return parts_content_type, content
def get_range_data(self, f, byte_range):
f.seek(byte_range.lower)
return f.read(byte_range.upper - byte_range.lower)
file_handler = FileHandler()
class PythonScriptHandler(object):
def __init__(self, base_path=None, url_base="/"):
self.base_path = base_path
self.url_base = url_base
def __repr__(self):
return "<%s base_path:%s url_base:%s>" % (self.__class__.__name__, self.base_path, self.url_base)
def _load_file(self, request, response, func):
"""
        This executes the requested Python file, using a fresh environ dict as
        its namespace. Once the environ is loaded, the passed `func` is run
        with this loaded environ.
:param request: The request object
:param response: The response object
:param func: The function to be run with the loaded environ with the modified filepath. Signature: (request, response, environ, path)
:return: The return of func
"""
path = filesystem_path(self.base_path, request, self.url_base)
try:
environ = {"__file__": path}
with open(path, 'rb') as f:
exec(compile(f.read(), path, 'exec'), environ, environ)
if func is not None:
return func(request, response, environ, path)
except IOError:
raise HTTPException(404)
def __call__(self, request, response):
def func(request, response, environ, path):
if "main" in environ:
handler = FunctionHandler(environ["main"])
handler(request, response)
wrap_pipeline(path, request, response)
else:
raise HTTPException(500, "No main function in script %s" % path)
self._load_file(request, response, func)
def frame_handler(self, request):
"""
This creates a FunctionHandler with one or more of the handling functions.
Used by the H2 server.
:param request: The request object used to generate the handler.
:return: A FunctionHandler object with one or more of these functions: `handle_headers`, `handle_data` or `main`
"""
def func(request, response, environ, path):
def _main(req, resp):
pass
handler = FunctionHandler(_main)
if "main" in environ:
handler.func = environ["main"]
if "handle_headers" in environ:
handler.handle_headers = environ["handle_headers"]
if "handle_data" in environ:
handler.handle_data = environ["handle_data"]
if handler.func is _main and not hasattr(handler, "handle_headers") and not hasattr(handler, "handle_data"):
raise HTTPException(500, "No main function or handlers in script %s" % path)
return handler
return self._load_file(request, None, func)
python_script_handler = PythonScriptHandler()
class FunctionHandler(object):
def __init__(self, func):
self.func = func
def __call__(self, request, response):
try:
rv = self.func(request, response)
except HTTPException:
raise
except Exception:
msg = traceback.format_exc()
raise HTTPException(500, message=msg)
if rv is not None:
if isinstance(rv, tuple):
if len(rv) == 3:
status, headers, content = rv
response.status = status
elif len(rv) == 2:
headers, content = rv
else:
raise HTTPException(500)
response.headers.update(headers)
else:
content = rv
response.content = content
wrap_pipeline('', request, response)
# The generic name here is so that this can be used as a decorator
def handler(func):
return FunctionHandler(func)
class JsonHandler(object):
def __init__(self, func):
self.func = func
def __call__(self, request, response):
return FunctionHandler(self.handle_request)(request, response)
def handle_request(self, request, response):
rv = self.func(request, response)
response.headers.set("Content-Type", "application/json")
enc = json.dumps
if isinstance(rv, tuple):
rv = list(rv)
value = tuple(rv[:-1] + [enc(rv[-1])])
length = len(value[-1])
else:
value = enc(rv)
length = len(value)
response.headers.set("Content-Length", length)
return value
def json_handler(func):
return JsonHandler(func)
class AsIsHandler(object):
def __init__(self, base_path=None, url_base="/"):
self.base_path = base_path
self.url_base = url_base
def __call__(self, request, response):
path = filesystem_path(self.base_path, request, self.url_base)
try:
with open(path, 'rb') as f:
response.writer.write_raw_content(f.read())
wrap_pipeline(path, request, response)
response.close_connection = True
except IOError:
raise HTTPException(404)
as_is_handler = AsIsHandler()
class BasicAuthHandler(object):
def __init__(self, handler, user, password):
"""
A Basic Auth handler
:Args:
        - handler: a secondary handler for the request after authentication is successful (e.g. file_handler)
- user: string of the valid user name or None if any / all credentials are allowed
- password: string of the password required
"""
self.user = user
self.password = password
self.handler = handler
def __call__(self, request, response):
if "authorization" not in request.headers:
response.status = 401
response.headers.set("WWW-Authenticate", "Basic")
return response
else:
auth = Authentication(request.headers)
if self.user is not None and (self.user != auth.username or self.password != auth.password):
response.set_error(403, "Invalid username or password")
return response
return self.handler(request, response)
basic_auth_handler = BasicAuthHandler(file_handler, None, None)
class ErrorHandler(object):
def __init__(self, status):
self.status = status
def __call__(self, request, response):
response.set_error(self.status)
class StringHandler(object):
def __init__(self, data, content_type, **headers):
"""Handler that returns a fixed data string and headers
:param data: String to use
        :param content_type: Content type header to serve the response with
        :param headers: Extra headers to send with responses"""
self.data = data
self.resp_headers = [("Content-Type", content_type)]
for k, v in iteritems(headers):
self.resp_headers.append((k.replace("_", "-"), v))
self.handler = handler(self.handle_request)
def handle_request(self, request, response):
return self.resp_headers, self.data
def __call__(self, request, response):
rv = self.handler(request, response)
return rv
class StaticHandler(StringHandler):
def __init__(self, path, format_args, content_type, **headers):
"""Handler that reads a file from a path and substitutes some fixed data
Note that *.headers files have no effect in this handler.
:param path: Path to the template file to use
:param format_args: Dictionary of values to substitute into the template file
        :param content_type: Content type header to serve the response with
        :param headers: Extra headers to send with responses"""
with open(path) as f:
data = f.read()
if format_args:
data = data % format_args
        super(StaticHandler, self).__init__(data, content_type, **headers)
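The `parse_qs` helper in this file deliberately splits only on `&`, so a semicolon survives inside a value instead of being treated as a second parameter delimiter (the stdlib helper historically split on `;` as well). A standalone sketch of the same logic, reproduced here so it can run outside the vendored module:

```python
from collections import defaultdict

try:
    from urllib.parse import unquote  # Python 3
except ImportError:
    from urllib import unquote  # Python 2


def parse_qs(qs):
    # Split on "&" only; ";" is kept as ordinary value data.
    pairs = [item.split("=", 1) for item in qs.split("&") if item]
    rv = defaultdict(list)
    for pair in pairs:
        # Pairs with no "=" or an empty value are dropped entirely.
        if len(pair) == 1 or len(pair[1]) == 0:
            continue
        name = unquote(pair[0].replace("+", " "))
        value = unquote(pair[1].replace("+", " "))
        rv[name].append(value)
    return dict(rv)
```

This keeps pipe strings such as `?pipe=sub(none)` intact even when a substitution argument contains a semicolon.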


@@ -0,0 +1,29 @@
class NoOpLogger(object):
def critical(self, msg):
pass
def error(self, msg):
pass
def info(self, msg):
pass
def warning(self, msg):
pass
def debug(self, msg):
pass
logger = NoOpLogger()
_set_logger = False
def set_logger(new_logger):
global _set_logger
if _set_logger:
raise Exception("Logger must be set at most once")
global logger
logger = new_logger
_set_logger = True
def get_logger():
return logger
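The logger module above keeps a single process-wide logger behind a set-at-most-once guard, so every component that calls `get_logger()` after startup sees the same object. The same pattern sketched standalone (a state dict here stands in for the module's two globals):

```python
class NoOpLogger(object):
    # Default logger that silently swallows all messages.
    def info(self, msg):
        pass


_state = {"logger": NoOpLogger(), "set": False}


def set_logger(new_logger):
    # Allow the logger to be replaced at most once; a second call is a
    # programming error rather than a silent swap.
    if _state["set"]:
        raise Exception("Logger must be set at most once")
    _state["logger"] = new_logger
    _state["set"] = True


def get_logger():
    return _state["logger"]
```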


@@ -0,0 +1,559 @@
from collections import deque
import base64
import gzip as gzip_module
import hashlib
import os
import re
import time
import uuid
from six.moves import StringIO
from six import text_type, binary_type
try:
from html import escape
except ImportError:
from cgi import escape
def resolve_content(response):
return b"".join(item for item in response.iter_content(read_file=True))
class Pipeline(object):
pipes = {}
def __init__(self, pipe_string):
self.pipe_functions = self.parse(pipe_string)
def parse(self, pipe_string):
functions = []
for item in PipeTokenizer().tokenize(pipe_string):
if not item:
break
if item[0] == "function":
functions.append((self.pipes[item[1]], []))
elif item[0] == "argument":
functions[-1][1].append(item[1])
return functions
def __call__(self, request, response):
for func, args in self.pipe_functions:
response = func(request, response, *args)
return response
class PipeTokenizer(object):
def __init__(self):
#This whole class can likely be replaced by some regexps
self.state = None
def tokenize(self, string):
self.string = string
self.state = self.func_name_state
self._index = 0
while self.state:
yield self.state()
yield None
def get_char(self):
if self._index >= len(self.string):
return None
rv = self.string[self._index]
self._index += 1
return rv
def func_name_state(self):
rv = ""
while True:
char = self.get_char()
if char is None:
self.state = None
if rv:
return ("function", rv)
else:
return None
elif char == "(":
self.state = self.argument_state
return ("function", rv)
elif char == "|":
if rv:
return ("function", rv)
else:
rv += char
def argument_state(self):
rv = ""
while True:
char = self.get_char()
if char is None:
self.state = None
return ("argument", rv)
            elif char == "\\":
                escaped = self.get_escape()
                if escaped is None:
                    #This should perhaps be an error instead
                    return ("argument", rv)
                rv += escaped
elif char == ",":
return ("argument", rv)
elif char == ")":
self.state = self.func_name_state
return ("argument", rv)
else:
rv += char
def get_escape(self):
char = self.get_char()
escapes = {"n": "\n",
"r": "\r",
"t": "\t"}
return escapes.get(char, char)
class pipe(object):
def __init__(self, *arg_converters):
self.arg_converters = arg_converters
self.max_args = len(self.arg_converters)
self.min_args = 0
opt_seen = False
for item in self.arg_converters:
if not opt_seen:
if isinstance(item, opt):
opt_seen = True
else:
self.min_args += 1
else:
if not isinstance(item, opt):
raise ValueError("Non-optional argument cannot follow optional argument")
def __call__(self, f):
def inner(request, response, *args):
if not (self.min_args <= len(args) <= self.max_args):
raise ValueError("Expected between %d and %d args, got %d" %
(self.min_args, self.max_args, len(args)))
arg_values = tuple(f(x) for f, x in zip(self.arg_converters, args))
return f(request, response, *arg_values)
Pipeline.pipes[f.__name__] = inner
#We actually want the undecorated function in the main namespace
return f
class opt(object):
def __init__(self, f):
self.f = f
def __call__(self, arg):
return self.f(arg)
def nullable(func):
def inner(arg):
if arg.lower() == "null":
return None
else:
return func(arg)
return inner
def boolean(arg):
if arg.lower() in ("true", "1"):
return True
elif arg.lower() in ("false", "0"):
return False
raise ValueError
@pipe(int)
def status(request, response, code):
"""Alter the status code.
:param code: Status code to use for the response."""
response.status = code
return response
@pipe(str, str, opt(boolean))
def header(request, response, name, value, append=False):
"""Set a HTTP header.
Replaces any existing HTTP header of the same name unless
append is set, in which case the header is appended without
replacement.
:param name: Name of the header to set.
:param value: Value to use for the header.
:param append: True if existing headers should not be replaced
"""
if not append:
response.headers.set(name, value)
else:
response.headers.append(name, value)
return response
@pipe(str)
def trickle(request, response, delays):
"""Send the response in parts, with time delays.
:param delays: A string of delays and amounts, in bytes, of the
response to send. Each component is separated by
a colon. Amounts in bytes are plain integers, whilst
delays are floats prefixed with a single d e.g.
d1:100:d2
Would cause a 1 second delay, would then send 100 bytes
of the file, and then cause a 2 second delay, before sending
the remainder of the file.
If the last token is of the form rN, instead of sending the
remainder of the file, the previous N instructions will be
repeated until the whole file has been sent e.g.
d1:100:d2:r2
Causes a delay of 1s, then 100 bytes to be sent, then a 2s delay
and then a further 100 bytes followed by a two second delay
until the response has been fully sent.
"""
def parse_delays():
parts = delays.split(":")
rv = []
for item in parts:
if item.startswith("d"):
item_type = "delay"
item = item[1:]
value = float(item)
elif item.startswith("r"):
item_type = "repeat"
value = int(item[1:])
if not value % 2 == 0:
raise ValueError
else:
item_type = "bytes"
value = int(item)
            if len(rv) and rv[-1][0] == item_type:
                # Entries are tuples, so rebuild rather than mutate in place
                rv[-1] = (item_type, rv[-1][1] + value)
            else:
                rv.append((item_type, value))
return rv
delays = parse_delays()
if not delays:
return response
content = resolve_content(response)
offset = [0]
if not ("Cache-Control" in response.headers or
"Pragma" in response.headers or
"Expires" in response.headers):
response.headers.set("Cache-Control", "no-cache, no-store, must-revalidate")
response.headers.set("Pragma", "no-cache")
response.headers.set("Expires", "0")
def add_content(delays, repeat=False):
for i, (item_type, value) in enumerate(delays):
if item_type == "bytes":
yield content[offset[0]:offset[0] + value]
offset[0] += value
elif item_type == "delay":
time.sleep(value)
elif item_type == "repeat":
if i != len(delays) - 1:
continue
while offset[0] < len(content):
for item in add_content(delays[-(value + 1):-1], True):
yield item
if not repeat and offset[0] < len(content):
yield content[offset[0]:]
response.content = add_content(delays)
return response
@pipe(nullable(int), opt(nullable(int)))
def slice(request, response, start, end=None):
"""Send a byte range of the response body
:param start: The starting offset. Follows python semantics including
negative numbers.
:param end: The ending offset, again with python semantics and None
(spelled "null" in a query string) to indicate the end of
the file.
"""
content = resolve_content(response)[start:end]
response.content = content
response.headers.set("Content-Length", len(content))
return response
class ReplacementTokenizer(object):
def arguments(self, token):
unwrapped = token[1:-1].decode('utf8')
return ("arguments", re.split(r",\s*", unwrapped) if unwrapped else [])
def ident(self, token):
return ("ident", token.decode('utf8'))
def index(self, token):
token = token[1:-1].decode('utf8')
try:
index = int(token)
except ValueError:
index = token
return ("index", index)
def var(self, token):
token = token[:-1].decode('utf8')
return ("var", token)
def tokenize(self, string):
assert isinstance(string, binary_type)
return self.scanner.scan(string)[0]
scanner = re.Scanner([(br"\$\w+:", var),
(br"\$?\w+", ident),
(br"\[[^\]]*\]", index),
(br"\([^)]*\)", arguments)])
class FirstWrapper(object):
def __init__(self, params):
self.params = params
def __getitem__(self, key):
try:
if isinstance(key, text_type):
key = key.encode('iso-8859-1')
return self.params.first(key)
except KeyError:
return ""
@pipe(opt(nullable(str)))
def sub(request, response, escape_type="html"):
"""Substitute environment information about the server and request into the script.
:param escape_type: String detailing the type of escaping to use. Known values are
"html" and "none", with "html" the default for historic reasons.
The format is a very limited template language. Substitutions are
enclosed by {{ and }}. There are several available substitutions:
host
A simple string value and represents the primary host from which the
tests are being run.
domains
A dictionary of available domains indexed by subdomain name.
ports
A dictionary of lists of ports indexed by protocol.
location
A dictionary of parts of the request URL. Valid keys are
    'server', 'scheme', 'host', 'hostname', 'port', 'path' and 'query'.
'server' is scheme://host:port, 'host' is hostname:port, and query
includes the leading '?', but other delimiters are omitted.
headers
A dictionary of HTTP headers in the request.
header_or_default(header, default)
The value of an HTTP header, or a default value if it is absent.
For example::
{{header_or_default(X-Test, test-header-absent)}}
GET
A dictionary of query parameters supplied with the request.
uuid()
    A pseudo-random UUID suitable for usage with stash
file_hash(algorithm, filepath)
The cryptographic hash of a file. Supported algorithms: md5, sha1,
sha224, sha256, sha384, and sha512. For example::
{{file_hash(md5, dom/interfaces.html)}}
fs_path(filepath)
The absolute path to a file inside the wpt document root
So for example in a setup running on localhost with a www
subdomain and a http server on ports 80 and 81::
{{host}} => localhost
{{domains[www]}} => www.localhost
{{ports[http][1]}} => 81
It is also possible to assign a value to a variable name, which must start
with the $ character, using the ":" syntax e.g.::
{{$id:uuid()}}
Later substitutions in the same file may then refer to the variable
by name e.g.::
{{$id}}
"""
content = resolve_content(response)
new_content = template(request, content, escape_type=escape_type)
response.content = new_content
return response
class SubFunctions(object):
@staticmethod
def uuid(request):
return str(uuid.uuid4())
# Maintain a list of supported algorithms, restricted to those that are
# available on all platforms [1]. This ensures that test authors do not
# unknowingly introduce platform-specific tests.
#
# [1] https://docs.python.org/2/library/hashlib.html
supported_algorithms = ("md5", "sha1", "sha224", "sha256", "sha384", "sha512")
@staticmethod
def file_hash(request, algorithm, path):
assert isinstance(algorithm, text_type)
if algorithm not in SubFunctions.supported_algorithms:
raise ValueError("Unsupported encryption algorithm: '%s'" % algorithm)
hash_obj = getattr(hashlib, algorithm)()
absolute_path = os.path.join(request.doc_root, path)
try:
with open(absolute_path, "rb") as f:
hash_obj.update(f.read())
except IOError:
# In this context, an unhandled IOError will be interpreted by the
# server as an indication that the template file is non-existent.
# Although the generic "Exception" is less precise, it avoids
# triggering a potentially-confusing HTTP 404 error in cases where
# the path to the file to be hashed is invalid.
raise Exception('Cannot open file for hash computation: "%s"' % absolute_path)
return base64.b64encode(hash_obj.digest()).strip()
@staticmethod
def fs_path(request, path):
if not path.startswith("/"):
subdir = request.request_path[len(request.url_base):]
if "/" in subdir:
subdir = subdir.rsplit("/", 1)[0]
root_rel_path = subdir + "/" + path
else:
root_rel_path = path[1:]
root_rel_path = root_rel_path.replace("/", os.path.sep)
absolute_path = os.path.abspath(os.path.join(request.doc_root, root_rel_path))
if ".." in os.path.relpath(absolute_path, request.doc_root):
raise ValueError("Path outside wpt root")
return absolute_path
@staticmethod
def header_or_default(request, name, default):
return request.headers.get(name, default)
def template(request, content, escape_type="html"):
#TODO: There basically isn't any error handling here
tokenizer = ReplacementTokenizer()
variables = {}
def config_replacement(match):
content, = match.groups()
tokens = tokenizer.tokenize(content)
tokens = deque(tokens)
token_type, field = tokens.popleft()
assert isinstance(field, text_type)
if token_type == "var":
variable = field
token_type, field = tokens.popleft()
assert isinstance(field, text_type)
else:
variable = None
if token_type != "ident":
raise Exception("unexpected token type %s (token '%r'), expected ident" % (token_type, field))
if field in variables:
value = variables[field]
elif hasattr(SubFunctions, field):
value = getattr(SubFunctions, field)
elif field == "headers":
value = request.headers
elif field == "GET":
value = FirstWrapper(request.GET)
elif field == "hosts":
value = request.server.config.all_domains
elif field == "domains":
value = request.server.config.all_domains[""]
elif field == "host":
value = request.server.config["browser_host"]
elif field in request.server.config:
value = request.server.config[field]
elif field == "location":
value = {"server": "%s://%s:%s" % (request.url_parts.scheme,
request.url_parts.hostname,
request.url_parts.port),
"scheme": request.url_parts.scheme,
"host": "%s:%s" % (request.url_parts.hostname,
request.url_parts.port),
"hostname": request.url_parts.hostname,
"port": request.url_parts.port,
"path": request.url_parts.path,
"pathname": request.url_parts.path,
"query": "?%s" % request.url_parts.query}
elif field == "url_base":
value = request.url_base
else:
raise Exception("Undefined template variable %s" % field)
while tokens:
ttype, field = tokens.popleft()
if ttype == "index":
value = value[field]
elif ttype == "arguments":
value = value(request, *field)
else:
raise Exception(
"unexpected token type %s (token '%r'), expected ident or arguments" % (ttype, field)
)
assert isinstance(value, (int, (binary_type, text_type))), tokens
if variable is not None:
variables[variable] = value
escape_func = {"html": lambda x:escape(x, quote=True),
"none": lambda x:x}[escape_type]
# Should possibly support escaping for other contexts e.g. script
# TODO: read the encoding of the response
# cgi.escape() only takes text strings in Python 3.
if isinstance(value, binary_type):
value = value.decode("utf-8")
elif isinstance(value, int):
value = text_type(value)
return escape_func(value).encode("utf-8")
template_regexp = re.compile(br"{{([^}]*)}}")
new_content = template_regexp.sub(config_replacement, content)
return new_content
@pipe()
def gzip(request, response):
"""This pipe gzip-encodes response data.
It sets (or overwrites) these HTTP headers:
Content-Encoding is set to gzip
Content-Length is set to the length of the compressed content
"""
content = resolve_content(response)
response.headers.set("Content-Encoding", "gzip")
    # gzip produces bytes; buffer them in io.BytesIO, which works on both
    # Python 2 and 3 (six.moves.StringIO rejects bytes under Python 3)
    from io import BytesIO
    out = BytesIO()
with gzip_module.GzipFile(fileobj=out, mode="w") as f:
f.write(content)
response.content = out.getvalue()
response.headers.set("Content-Length", len(response.content))
return response
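The delay-string grammar documented in the `trickle` pipe above can be sketched as a standalone parser. This mirrors the nested `parse_delays` helper (with the merge step rebuilding tuples rather than mutating them); it is a sketch for illustration, not the vendored code itself:

```python
def parse_delays(delays):
    # Instruction list: ("delay", seconds), ("bytes", count) or ("repeat", n).
    rv = []
    for item in delays.split(":"):
        if item.startswith("d"):
            item_type, value = "delay", float(item[1:])
        elif item.startswith("r"):
            # A trailing rN replays the previous N instructions; N must be
            # even so delay and byte-count instructions stay paired.
            item_type, value = "repeat", int(item[1:])
            if value % 2 != 0:
                raise ValueError("repeat count must be even")
        else:
            item_type, value = "bytes", int(item)
        if rv and rv[-1][0] == item_type:
            # Merge adjacent instructions of the same type.
            rv[-1] = (item_type, rv[-1][1] + value)
        else:
            rv.append((item_type, value))
    return rv


print(parse_delays("d1:100:d2:r2"))  # [('delay', 1.0), ('bytes', 100), ('delay', 2.0), ('repeat', 2)]
```

So `d1:100:d2:r2` means: wait 1s, send 100 bytes, wait 2s, then repeat the last two instructions until the body is fully sent.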


@@ -0,0 +1,94 @@
from .utils import HTTPException
class RangeParser(object):
def __call__(self, header, file_size):
try:
header = header.decode("ascii")
except UnicodeDecodeError:
raise HTTPException(400, "Non-ASCII range header value")
prefix = "bytes="
if not header.startswith(prefix):
raise HTTPException(416, message="Unrecognised range type %s" % (header,))
parts = header[len(prefix):].split(",")
ranges = []
for item in parts:
components = item.split("-")
if len(components) != 2:
raise HTTPException(416, "Bad range specifier %s" % (item))
data = []
for component in components:
if component == "":
data.append(None)
else:
try:
data.append(int(component))
except ValueError:
raise HTTPException(416, "Bad range specifier %s" % (item))
try:
ranges.append(Range(data[0], data[1], file_size))
except ValueError:
raise HTTPException(416, "Bad range specifier %s" % (item))
return self.coalesce_ranges(ranges, file_size)
def coalesce_ranges(self, ranges, file_size):
rv = []
target = None
for current in reversed(sorted(ranges)):
if target is None:
target = current
else:
new = target.coalesce(current)
target = new[0]
if len(new) > 1:
rv.append(new[1])
rv.append(target)
return rv[::-1]
class Range(object):
def __init__(self, lower, upper, file_size):
self.file_size = file_size
self.lower, self.upper = self._abs(lower, upper)
if self.lower >= self.upper or self.lower >= self.file_size:
raise ValueError
def __repr__(self):
return "<Range %s-%s>" % (self.lower, self.upper)
def __lt__(self, other):
return self.lower < other.lower
def __gt__(self, other):
return self.lower > other.lower
def __eq__(self, other):
return self.lower == other.lower and self.upper == other.upper
def _abs(self, lower, upper):
if lower is None and upper is None:
lower, upper = 0, self.file_size
elif lower is None:
lower, upper = max(0, self.file_size - upper), self.file_size
elif upper is None:
lower, upper = lower, self.file_size
else:
lower, upper = lower, min(self.file_size, upper + 1)
return lower, upper
def coalesce(self, other):
assert self.file_size == other.file_size
if (self.upper < other.lower or self.lower > other.upper):
return sorted([self, other])
else:
return [Range(min(self.lower, other.lower),
max(self.upper, other.upper) - 1,
self.file_size)]
def header_value(self):
return "bytes %i-%i/%i" % (self.lower, self.upper - 1, self.file_size)
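The `Range`/`RangeParser` behaviour above reduces to arithmetic on half-open `(lower, upper)` intervals: `_abs` resolves suffix and open-ended ranges against the file size, coalescing merges overlapping or touching ranges, and `Content-Range` reports inclusive bounds. A minimal standalone sketch, with plain tuples standing in for the `Range` class:

```python
def resolve(lower, upper, file_size):
    # Mirror Range._abs: suffix ranges ("-N"), open-ended ranges ("N-"),
    # and clamping to the file size. Returns a half-open interval.
    if lower is None and upper is None:
        return 0, file_size
    if lower is None:
        return max(0, file_size - upper), file_size
    if upper is None:
        return lower, file_size
    return lower, min(file_size, upper + 1)


def coalesce(ranges):
    # Merge overlapping or touching half-open intervals, mirroring
    # RangeParser.coalesce_ranges.
    rv = []
    for lower, upper in sorted(ranges):
        if rv and lower <= rv[-1][1]:
            rv[-1] = (rv[-1][0], max(rv[-1][1], upper))
        else:
            rv.append((lower, upper))
    return rv


def header_value(rng, file_size):
    # Content-Range uses inclusive byte positions.
    return "bytes %i-%i/%i" % (rng[0], rng[1] - 1, file_size)
```

For example, `bytes=0-9,5-14` against a 100-byte file coalesces to a single part reported as `bytes 0-14/100`.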


@@ -0,0 +1,688 @@
import base64
import cgi
import tempfile
from six import BytesIO, binary_type, iteritems, PY3
from six.moves.http_cookies import BaseCookie
from six.moves.urllib.parse import parse_qsl, urlsplit
from . import stash
from .utils import HTTPException, isomorphic_encode, isomorphic_decode
missing = object()
class Server(object):
"""Data about the server environment
.. attribute:: config
Environment configuration information with information about the
various servers running, their hostnames and ports.
.. attribute:: stash
Stash object holding state stored on the server between requests.
"""
config = None
def __init__(self, request):
self._stash = None
self._request = request
@property
def stash(self):
if self._stash is None:
address, authkey = stash.load_env_config()
self._stash = stash.Stash(self._request.url_parts.path, address, authkey)
return self._stash
class InputFile(object):
max_buffer_size = 1024*1024
def __init__(self, rfile, length):
"""File-like object used to provide a seekable view of request body data"""
self._file = rfile
self.length = length
self._file_position = 0
if length > self.max_buffer_size:
self._buf = tempfile.TemporaryFile()
else:
self._buf = BytesIO()
@property
def _buf_position(self):
rv = self._buf.tell()
assert rv <= self._file_position
return rv
def read(self, bytes=-1):
assert self._buf_position <= self._file_position
if bytes < 0:
bytes = self.length - self._buf_position
bytes_remaining = min(bytes, self.length - self._buf_position)
if bytes_remaining == 0:
return b""
if self._buf_position != self._file_position:
buf_bytes = min(bytes_remaining, self._file_position - self._buf_position)
old_data = self._buf.read(buf_bytes)
bytes_remaining -= buf_bytes
else:
old_data = b""
assert bytes_remaining == 0 or self._buf_position == self._file_position, (
"Before reading buffer position (%i) didn't match file position (%i)" %
(self._buf_position, self._file_position))
new_data = self._file.read(bytes_remaining)
self._buf.write(new_data)
self._file_position += bytes_remaining
assert bytes_remaining == 0 or self._buf_position == self._file_position, (
"After reading buffer position (%i) didn't match file position (%i)" %
(self._buf_position, self._file_position))
return old_data + new_data
def tell(self):
return self._buf_position
def seek(self, offset):
if offset > self.length or offset < 0:
raise ValueError
if offset <= self._file_position:
self._buf.seek(offset)
else:
self.read(offset - self._file_position)
def readline(self, max_bytes=None):
if max_bytes is None:
max_bytes = self.length - self._buf_position
if self._buf_position < self._file_position:
data = self._buf.readline(max_bytes)
if data.endswith(b"\n") or len(data) == max_bytes:
return data
else:
data = b""
assert self._buf_position == self._file_position
initial_position = self._file_position
found = False
buf = []
max_bytes -= len(data)
while not found:
readahead = self.read(min(2, max_bytes))
max_bytes -= len(readahead)
for i, c in enumerate(readahead):
if c == b"\n"[0]:
buf.append(readahead[:i+1])
found = True
break
if not found:
buf.append(readahead)
if not readahead or not max_bytes:
break
new_data = b"".join(buf)
data += new_data
self.seek(initial_position + len(new_data))
return data
def readlines(self):
rv = []
while True:
data = self.readline()
if data:
rv.append(data)
else:
break
return rv
def __next__(self):
data = self.readline()
if data:
return data
else:
raise StopIteration
next = __next__
def __iter__(self):
return self
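`InputFile` makes the non-seekable socket stream behave like a seekable file by mirroring everything it reads into a buffer (a `BytesIO`, or a temp file for large bodies). A toy illustration of that replay trick (the class name is invented for the example):

```python
from io import BytesIO

class ReplayableStream(object):
    """Toy version of InputFile's buffering: everything read from the
    forward-only source is mirrored into a buffer for later re-reads."""
    def __init__(self, raw):
        self._raw = raw        # forward-only source, e.g. a socket file
        self._buf = BytesIO()  # replay buffer

    def read(self, n=-1):
        data = self._raw.read(n)
        self._buf.write(data)  # remember what went past
        return data

    def replay(self):
        return self._buf.getvalue()

stream = ReplayableStream(BytesIO(b"POST /submit HTTP/1.1"))
assert stream.read(4) == b"POST"
stream.read()  # drain the rest
assert stream.replay() == b"POST /submit HTTP/1.1"
```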
class Request(object):
"""Object representing a HTTP request.
.. attribute:: doc_root
The local directory to use as a base when resolving paths
.. attribute:: route_match
Regexp match object from matching the request path to the route
selected for the request.
.. attribute:: client_address
Contains a tuple of the form (host, port) representing the client's address.
.. attribute:: protocol_version
HTTP version specified in the request.
.. attribute:: method
HTTP method in the request.
.. attribute:: request_path
Request path as it appears in the HTTP request.
.. attribute:: url_base
The prefix part of the path; typically / unless the handler has a url_base set
.. attribute:: url
Absolute URL for the request.
.. attribute:: url_parts
Parts of the requested URL as obtained by urlparse.urlsplit(path)
.. attribute:: request_line
Raw request line
.. attribute:: headers
RequestHeaders object providing a dictionary-like representation of
the request headers.
.. attribute:: raw_headers
Dictionary of non-normalized request headers.
.. attribute:: body
Request body as a string
.. attribute:: raw_input
File-like object representing the body of the request.
.. attribute:: GET
MultiDict representing the parameters supplied with the request.
Note that these may be present on non-GET requests; the name is
chosen to be familiar to users of other systems such as PHP.
Both keys and values are binary strings.
.. attribute:: POST
MultiDict representing the request body parameters. Most parameters
are present as string values, but file uploads have file-like
values. All string values (including keys) have binary type.
.. attribute:: cookies
A Cookies object representing cookies sent with the request with a
dictionary-like interface.
.. attribute:: auth
An instance of Authentication with username and password properties
representing any credentials supplied using HTTP authentication.
.. attribute:: server
Server object containing information about the server environment.
"""
def __init__(self, request_handler):
self.doc_root = request_handler.server.router.doc_root
self.route_match = None # Set by the router
self.client_address = request_handler.client_address
self.protocol_version = request_handler.protocol_version
self.method = request_handler.command
# Keys and values in raw headers are native strings.
self._headers = None
self.raw_headers = request_handler.headers
scheme = request_handler.server.scheme
host = self.raw_headers.get("Host")
port = request_handler.server.server_address[1]
if host is None:
host = request_handler.server.server_address[0]
else:
if ":" in host:
host, port = host.split(":", 1)
self.request_path = request_handler.path
self.url_base = "/"
if self.request_path.startswith(scheme + "://"):
self.url = self.request_path
else:
# TODO(#23362): Stop using native strings for URLs.
self.url = "%s://%s:%s%s" % (
scheme, host, port, self.request_path)
self.url_parts = urlsplit(self.url)
self.request_line = request_handler.raw_requestline
self.raw_input = InputFile(request_handler.rfile,
int(self.raw_headers.get("Content-Length", 0)))
self._body = None
self._GET = None
self._POST = None
self._cookies = None
self._auth = None
self.server = Server(self)
def __repr__(self):
return "<Request %s %s>" % (self.method, self.url)
@property
def GET(self):
if self._GET is None:
kwargs = {
"keep_blank_values": True,
}
if PY3:
kwargs["encoding"] = "iso-8859-1"
params = parse_qsl(self.url_parts.query, **kwargs)
self._GET = MultiDict()
for key, value in params:
self._GET.add(isomorphic_encode(key), isomorphic_encode(value))
return self._GET
@property
def POST(self):
if self._POST is None:
# Work out the post parameters
pos = self.raw_input.tell()
self.raw_input.seek(0)
kwargs = {
"fp": self.raw_input,
"environ": {"REQUEST_METHOD": self.method},
"headers": self.raw_headers,
"keep_blank_values": True,
}
if PY3:
kwargs["encoding"] = "iso-8859-1"
fs = cgi.FieldStorage(**kwargs)
self._POST = MultiDict.from_field_storage(fs)
self.raw_input.seek(pos)
return self._POST
@property
def cookies(self):
if self._cookies is None:
parser = BinaryCookieParser()
cookie_headers = self.headers.get("cookie", b"")
parser.load(cookie_headers)
cookies = Cookies()
for key, value in iteritems(parser):
cookies[isomorphic_encode(key)] = CookieValue(value)
self._cookies = cookies
return self._cookies
@property
def headers(self):
if self._headers is None:
self._headers = RequestHeaders(self.raw_headers)
return self._headers
@property
def body(self):
if self._body is None:
pos = self.raw_input.tell()
self.raw_input.seek(0)
self._body = self.raw_input.read()
self.raw_input.seek(pos)
return self._body
@property
def auth(self):
if self._auth is None:
self._auth = Authentication(self.headers)
return self._auth
class H2Request(Request):
def __init__(self, request_handler):
self.h2_stream_id = request_handler.h2_stream_id
self.frames = []
super(H2Request, self).__init__(request_handler)
class RequestHeaders(dict):
"""Read-only dictionary-like API for accessing request headers.
Unlike BaseHTTPRequestHandler.headers, this class always returns all
headers with the same name (separated by commas). And it ensures all keys
(i.e. names of headers) and values have binary type.
"""
def __init__(self, items):
for header in items.keys():
key = isomorphic_encode(header).lower()
# get all headers with the same name
values = items.getallmatchingheaders(header)
if len(values) > 1:
# collect the multiple variations of the current header
multiples = []
# loop through the values from getallmatchingheaders
for value in values:
# getallmatchingheaders returns raw header lines, so
# split to get name, value
multiples.append(isomorphic_encode(value).split(b':', 1)[1].strip())
headers = multiples
else:
headers = [isomorphic_encode(items[header])]
dict.__setitem__(self, key, headers)
def __getitem__(self, key):
"""Get all headers of a certain (case-insensitive) name. If there is
more than one, the values are returned comma separated"""
key = isomorphic_encode(key)
values = dict.__getitem__(self, key.lower())
if len(values) == 1:
return values[0]
else:
return b", ".join(values)
def __setitem__(self, name, value):
raise Exception
def get(self, key, default=None):
"""Get a string representing all headers with a particular value,
with multiple headers separated by a comma. If no header is found
return a default value
:param key: The header name to look up (case-insensitive)
:param default: The value to return in the case of no match
"""
try:
return self[key]
except KeyError:
return default
def get_list(self, key, default=missing):
"""Get all the header values for a particular field name as
a list"""
key = isomorphic_encode(key)
try:
return dict.__getitem__(self, key.lower())
except KeyError:
if default is not missing:
return default
else:
raise
def __contains__(self, key):
key = isomorphic_encode(key)
return dict.__contains__(self, key.lower())
def iteritems(self):
for item in self:
yield item, self[item]
def itervalues(self):
for item in self:
yield self[item]
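`RequestHeaders` lower-cases field names on construction and comma-joins duplicates on lookup, following the RFC 7230 rule for combining repeated fields. A minimal standalone sketch of those two steps (the helper names are invented):

```python
def normalize_headers(raw_items):
    """Sketch of RequestHeaders.__init__: lower-case each field name and
    collect every value seen for it. raw_items is a list of (name, value)
    byte-string pairs, standing in for the stdlib message object."""
    headers = {}
    for name, value in raw_items:
        headers.setdefault(name.lower(), []).append(value)
    return headers

def get_header(headers, name):
    # Sketch of __getitem__: duplicates are comma-joined (RFC 7230 style)
    values = headers[name.lower()]
    return values[0] if len(values) == 1 else b", ".join(values)

h = normalize_headers([(b"Accept-Encoding", b"gzip"),
                       (b"accept-encoding", b"br")])
assert get_header(h, b"Accept-Encoding") == b"gzip, br"
```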
class CookieValue(object):
"""Representation of cookies.
Note that cookies are considered read-only and the string value
of the cookie will not change if you update the field values.
However, this is not enforced.
.. attribute:: key
The name of the cookie.
.. attribute:: value
The value of the cookie
.. attribute:: expires
The expiry date of the cookie
.. attribute:: path
The path of the cookie
.. attribute:: comment
The comment of the cookie.
.. attribute:: domain
The domain with which the cookie is associated
.. attribute:: max_age
The max-age value of the cookie.
.. attribute:: secure
Whether the cookie is marked as secure
.. attribute:: httponly
Whether the cookie is marked as httponly
"""
def __init__(self, morsel):
self.key = morsel.key
self.value = morsel.value
for attr in ["expires", "path",
"comment", "domain", "max-age",
"secure", "version", "httponly"]:
setattr(self, attr.replace("-", "_"), morsel[attr])
self._str = morsel.OutputString()
def __str__(self):
return self._str
def __repr__(self):
return self._str
def __eq__(self, other):
"""Equality comparison for cookies. Compares to other cookies
based on value alone and on non-cookies based on the equality
of self.value with the other object so that a cookie with value
"ham" compares equal to the string "ham"
"""
if hasattr(other, "value"):
return self.value == other.value
return self.value == other
class MultiDict(dict):
"""Dictionary type that holds multiple values for each key"""
# TODO: this should perhaps also order the keys
def __init__(self):
pass
def __setitem__(self, name, value):
dict.__setitem__(self, name, [value])
def add(self, name, value):
if name in self:
dict.__getitem__(self, name).append(value)
else:
dict.__setitem__(self, name, [value])
def __getitem__(self, key):
"""Get the first value with a given key"""
return self.first(key)
def first(self, key, default=missing):
"""Get the first value with a given key
:param key: The key to lookup
:param default: The default to return if key is
not found (throws if nothing is
specified)
"""
if key in self and dict.__getitem__(self, key):
return dict.__getitem__(self, key)[0]
elif default is not missing:
return default
raise KeyError(key)
def last(self, key, default=missing):
"""Get the last value with a given key
:param key: The key to lookup
:param default: The default to return if key is
not found (throws if nothing is
specified)
"""
if key in self and dict.__getitem__(self, key):
return dict.__getitem__(self, key)[-1]
elif default is not missing:
return default
raise KeyError(key)
# We need to explicitly override dict.get; otherwise, it won't call
# __getitem__ and would return a list instead.
def get(self, key, default=None):
"""Get the first value with a given key
:param key: The key to lookup
:param default: The default to return if key is
not found (None by default)
"""
return self.first(key, default)
def get_list(self, key):
"""Get all values with a given key as a list
:param key: The key to lookup
"""
if key in self:
return dict.__getitem__(self, key)
else:
return []
@classmethod
def from_field_storage(cls, fs):
"""Construct a MultiDict from a cgi.FieldStorage
Note that all keys and values are binary strings.
"""
self = cls()
if fs.list is None:
return self
for key in fs:
values = fs[key]
if not isinstance(values, list):
values = [values]
for value in values:
if not value.filename:
value = isomorphic_encode(value.value)
else:
assert isinstance(value, cgi.FieldStorage)
self.add(isomorphic_encode(key), value)
return self
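The multi-value behaviour matters because query strings, form bodies, and cookies can all repeat a key: every value is kept, and `first()`/`last()` pick the ends of the list. A minimal sketch of that semantics (the class name is invented):

```python
class MiniMultiDict(dict):
    """Condensed illustration of MultiDict: each key maps to a list of
    values, with accessors for the first and last entries."""
    def add(self, key, value):
        dict.setdefault(self, key, []).append(value)

    def first(self, key):
        return dict.__getitem__(self, key)[0]

    def last(self, key):
        return dict.__getitem__(self, key)[-1]

# A repeated query parameter keeps both values
d = MiniMultiDict()
d.add(b"tag", b"red")
d.add(b"tag", b"blue")
assert d.first(b"tag") == b"red"   # what Request.GET would return
assert d.last(b"tag") == b"blue"   # what Cookies lookups prefer
```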
class BinaryCookieParser(BaseCookie):
"""A subclass of BaseCookie that returns values in binary strings
This is not intended to store the cookies; use Cookies instead.
"""
def value_decode(self, val):
"""Decode value from network to (real_value, coded_value).
Override BaseCookie.value_decode.
"""
return isomorphic_encode(val), val
def value_encode(self, val):
raise NotImplementedError('BinaryCookieParser is not for setting cookies')
def load(self, rawdata):
"""Load cookies from a binary string.
This overrides and calls BaseCookie.load. Unlike BaseCookie.load, it
does not accept dictionaries.
"""
assert isinstance(rawdata, binary_type)
if PY3:
# BaseCookie.load expects a native string, which in Python 3 is text.
rawdata = isomorphic_decode(rawdata)
super(BinaryCookieParser, self).load(rawdata)
class Cookies(MultiDict):
"""MultiDict specialised for Cookie values
Keys and values are binary strings.
"""
def __init__(self):
pass
def __getitem__(self, key):
return self.last(key)
class Authentication(object):
"""Object for dealing with HTTP Authentication
.. attribute:: username
The username supplied in the HTTP Authorization
header, or None
.. attribute:: password
The password supplied in the HTTP Authorization
header, or None
Both attributes are binary strings (`str` in Py2, `bytes` in Py3), since
RFC7617 Section 2.1 does not specify the encoding for username & password
(as long as it's compatible with ASCII). UTF-8 should be a relatively safe
choice if callers need to decode them as most browsers use it.
"""
def __init__(self, headers):
self.username = None
self.password = None
auth_schemes = {b"Basic": self.decode_basic}
if "authorization" in headers:
header = headers.get("authorization")
assert isinstance(header, binary_type)
auth_type, data = header.split(b" ", 1)
if auth_type in auth_schemes:
self.username, self.password = auth_schemes[auth_type](data)
else:
raise HTTPException(400, "Unsupported authentication scheme %s" % auth_type)
def decode_basic(self, data):
assert isinstance(data, binary_type)
decoded_data = base64.b64decode(data)
return decoded_data.split(b":", 1)
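`decode_basic` splits the base64-decoded payload on the first colon only, so passwords containing `:` survive intact. A self-contained sketch of the same round trip using only the standard library:

```python
import base64

def decode_basic(data):
    # Same shape as Authentication.decode_basic: the base64 payload is
    # "username:password", split on the first colon only.
    decoded = base64.b64decode(data)
    return decoded.split(b":", 1)

header = b"Basic " + base64.b64encode(b"alice:s3cret:extra")
auth_type, payload = header.split(b" ", 1)
assert auth_type == b"Basic"
# The colon inside the password is preserved
assert decode_basic(payload) == [b"alice", b"s3cret:extra"]
```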


@ -0,0 +1,813 @@
from collections import OrderedDict
from datetime import datetime, timedelta
from io import BytesIO
import json
import socket
import uuid
from hpack.struct import HeaderTuple
from hyperframe.frame import HeadersFrame, DataFrame, ContinuationFrame
from six import binary_type, text_type, integer_types, itervalues, PY3
from six.moves.http_cookies import BaseCookie, Morsel
from .constants import response_codes, h2_headers
from .logger import get_logger
from .utils import isomorphic_decode, isomorphic_encode
missing = object()
class Response(object):
"""Object representing the response to a HTTP request
:param handler: RequestHandler being used for this response
:param request: Request that this is the response for
.. attribute:: request
Request associated with this Response.
.. attribute:: encoding
The encoding to use when converting unicode to strings for output.
.. attribute:: add_required_headers
Boolean indicating whether mandatory headers should be added to the
response.
.. attribute:: send_body_for_head_request
Boolean, default False, indicating whether the body content should be
sent when the request method is HEAD.
.. attribute:: writer
The ResponseWriter for this response
.. attribute:: status
Status tuple (code, message). Can be set to an integer in which case the
message part is filled in automatically, or a tuple (code, message) in
which case code is an int and message is a text or binary string.
.. attribute:: headers
List of HTTP headers to send with the response. Each item in the list is a
tuple of (name, value).
.. attribute:: content
The body of the response. This can either be a string or an iterable of response
parts. If it is an iterable, any item may be a string or a function of zero
parameters which, when called, returns a string."""
def __init__(self, handler, request, response_writer_cls=None):
self.request = request
self.encoding = "utf8"
self.add_required_headers = True
self.send_body_for_head_request = False
self.close_connection = False
self.logger = get_logger()
self.writer = response_writer_cls(handler, self) if response_writer_cls else ResponseWriter(handler, self)
self._status = (200, None)
self.headers = ResponseHeaders()
self.content = []
@property
def status(self):
return self._status
@status.setter
def status(self, value):
if hasattr(value, "__len__"):
if len(value) != 2:
raise ValueError
else:
code = int(value[0])
message = value[1]
# Only call str() if message is not a string type, so that we
# don't get `str(b"foo") == "b'foo'"` in Python 3.
if not isinstance(message, (binary_type, text_type)):
message = str(message)
self._status = (code, message)
else:
self._status = (int(value), None)
def set_cookie(self, name, value, path="/", domain=None, max_age=None,
expires=None, secure=False, httponly=False, comment=None):
"""Set a cookie to be sent with a Set-Cookie header in the
response
:param name: name of the cookie (a binary string)
:param value: value of the cookie (a binary string, or None)
:param max_age: datetime.timedelta or int representing the time (in seconds)
until the cookie expires
:param path: String path to which the cookie applies
:param domain: String domain to which the cookie applies
:param secure: Boolean indicating whether the cookie is marked as secure
:param httponly: Boolean indicating whether the cookie is marked as
HTTP Only
:param comment: String comment
:param expires: datetime.datetime or datetime.timedelta indicating a
time or interval from now when the cookie expires
"""
# TODO(Python 3): Convert other parameters (e.g. path) to bytes, too.
if value is None:
value = b''
max_age = 0
expires = timedelta(days=-1)
if PY3:
name = isomorphic_decode(name)
value = isomorphic_decode(value)
days = {i+1: name for i, name in enumerate(["jan", "feb", "mar",
"apr", "may", "jun",
"jul", "aug", "sep",
"oct", "nov", "dec"])}
if isinstance(expires, timedelta):
expires = datetime.utcnow() + expires
if expires is not None:
expires_str = expires.strftime("%d %%s %Y %H:%M:%S GMT")
expires_str = expires_str % days[expires.month]
expires = expires_str
if max_age is not None:
if hasattr(max_age, "total_seconds"):
max_age = int(max_age.total_seconds())
max_age = "%.0d" % max_age
m = Morsel()
def maybe_set(key, value):
if value is not None and value is not False:
m[key] = value
m.set(name, value, value)
maybe_set("path", path)
maybe_set("domain", domain)
maybe_set("comment", comment)
maybe_set("expires", expires)
maybe_set("max-age", max_age)
maybe_set("secure", secure)
maybe_set("httponly", httponly)
self.headers.append("Set-Cookie", m.OutputString())
def unset_cookie(self, name):
"""Remove a cookie from those that are being sent with the response"""
if PY3:
name = isomorphic_decode(name)
cookies = self.headers.get("Set-Cookie")
parser = BaseCookie()
for cookie in cookies:
if PY3:
# BaseCookie.load expects a text string.
cookie = isomorphic_decode(cookie)
parser.load(cookie)
if name in parser.keys():
del self.headers["Set-Cookie"]
for m in parser.values():
if m.key != name:
self.headers.append("Set-Cookie", m.OutputString())
def delete_cookie(self, name, path="/", domain=None):
"""Delete a cookie on the client by setting it to the empty string
and to expire in the past"""
self.set_cookie(name, None, path=path, domain=domain, max_age=0,
expires=timedelta(days=-1))
def iter_content(self, read_file=False):
"""Iterator returning chunks of response body content.
If any part of the content is a function, this will be called
and the resulting value (if any) returned.
:param read_file: boolean controlling the behaviour when content is a
file handle. When set to False the handle will be
returned directly allowing the file to be passed to
the output in small chunks. When set to True, the
entire content of the file will be returned as a
string facilitating non-streaming operations like
template substitution.
"""
if isinstance(self.content, binary_type):
yield self.content
elif isinstance(self.content, text_type):
yield self.content.encode(self.encoding)
elif hasattr(self.content, "read"):
if read_file:
yield self.content.read()
else:
yield self.content
else:
for item in self.content:
if hasattr(item, "__call__"):
value = item()
else:
value = item
if value:
yield value
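The normalisation above accepts bytes, text, file handles, or an iterable that may contain zero-argument callables. A standalone sketch of the non-file cases (Python 3 `bytes`/`str` stand in for `binary_type`/`text_type`):

```python
def iter_content(content, encoding="utf8"):
    """Standalone sketch of Response.iter_content (file handles omitted):
    bytes pass through, text is encoded, and callables inside an
    iterable are invoked lazily; falsy results are skipped."""
    if isinstance(content, bytes):
        yield content
    elif isinstance(content, str):
        yield content.encode(encoding)
    else:
        for item in content:
            value = item() if callable(item) else item
            if value:
                yield value

assert list(iter_content(b"raw")) == [b"raw"]
assert list(iter_content("café")) == ["café".encode("utf8")]
# Callables are evaluated at write time; empty chunks are dropped
assert list(iter_content([b"a", lambda: b"b", b""])) == [b"a", b"b"]
```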
def write_status_headers(self):
"""Write out the status line and headers for the response"""
self.writer.write_status(*self.status)
for item in self.headers:
self.writer.write_header(*item)
self.writer.end_headers()
def write_content(self):
"""Write out the response content"""
if self.request.method != "HEAD" or self.send_body_for_head_request:
for item in self.iter_content():
self.writer.write_content(item)
def write(self):
"""Write the whole response"""
self.write_status_headers()
self.write_content()
def set_error(self, code, message=u""):
"""Set the response status headers and return a JSON error object:
{"error": {"code": code, "message": message}}
code is an int (HTTP status code), and message is a text string.
"""
err = {"code": code,
"message": message}
data = json.dumps({"error": err})
self.status = code
self.headers = [("Content-Type", "application/json"),
("Content-Length", len(data))]
self.content = data
if code == 500:
self.logger.error(message)
class MultipartContent(object):
def __init__(self, boundary=None, default_content_type=None):
self.items = []
if boundary is None:
boundary = text_type(uuid.uuid4())
self.boundary = boundary
self.default_content_type = default_content_type
def __call__(self):
boundary = b"--" + self.boundary.encode("ascii")
rv = [b"", boundary]
for item in self.items:
rv.append(item.to_bytes())
rv.append(boundary)
rv[-1] += b"--"
return b"\r\n".join(rv)
def append_part(self, data, content_type=None, headers=None):
if content_type is None:
content_type = self.default_content_type
self.items.append(MultipartPart(data, content_type, headers))
def __iter__(self):
#This is hackish; when writing the response we need an iterable
#or a string. For a multipart/byterange response we want an
#iterable that contains a single callable; the MultipartContent
#object itself
yield self
class MultipartPart(object):
def __init__(self, data, content_type=None, headers=None):
assert isinstance(data, binary_type), data
self.headers = ResponseHeaders()
if content_type is not None:
self.headers.set("Content-Type", content_type)
if headers is not None:
for name, value in headers:
if name.lower() == b"content-type":
func = self.headers.set
else:
func = self.headers.append
func(name, value)
self.data = data
def to_bytes(self):
rv = []
for key, value in self.headers:
assert isinstance(key, binary_type)
assert isinstance(value, binary_type)
rv.append(b"%s: %s" % (key, value))
rv.append(b"")
rv.append(self.data)
return b"\r\n".join(rv)
def _maybe_encode(s):
"""Encode a string or an int into binary data using isomorphic_encode()."""
if isinstance(s, integer_types):
return b"%i" % (s,)
return isomorphic_encode(s)
class ResponseHeaders(object):
"""Dictionary-like object holding the headers for the response"""
def __init__(self):
self.data = OrderedDict()
def set(self, key, value):
"""Set a header to a specific value, overwriting any previous header
with the same name
:param key: Name of the header to set
:param value: Value to set the header to
"""
key = _maybe_encode(key)
value = _maybe_encode(value)
self.data[key.lower()] = (key, [value])
def append(self, key, value):
"""Add a new header with a given name, not overwriting any existing
headers with the same name
:param key: Name of the header to add
:param value: Value to set for the header
"""
key = _maybe_encode(key)
value = _maybe_encode(value)
if key.lower() in self.data:
self.data[key.lower()][1].append(value)
else:
self.set(key, value)
def get(self, key, default=missing):
"""Get the set values for a particular header."""
key = _maybe_encode(key)
try:
return self[key]
except KeyError:
if default is missing:
return []
return default
def __getitem__(self, key):
"""Get a list of values for a particular header
"""
key = _maybe_encode(key)
return self.data[key.lower()][1]
def __delitem__(self, key):
key = _maybe_encode(key)
del self.data[key.lower()]
def __contains__(self, key):
key = _maybe_encode(key)
return key.lower() in self.data
def __setitem__(self, key, value):
self.set(key, value)
def __iter__(self):
for key, values in itervalues(self.data):
for value in values:
yield key, value
def items(self):
return list(self)
def update(self, items_iter):
for name, value in items_iter:
self.append(name, value)
def __repr__(self):
return repr(self.data)
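`ResponseHeaders` keys case-insensitively but remembers the original casing for output, and `append` accumulates values where `set` overwrites. A condensed sketch of that storage scheme (the class name is invented):

```python
from collections import OrderedDict

class MiniHeaders(object):
    """Sketch of ResponseHeaders' storage: the dict is keyed on the
    lower-cased name, but the original casing is kept for output."""
    def __init__(self):
        self.data = OrderedDict()

    def set(self, key, value):
        self.data[key.lower()] = (key, [value])   # overwrite

    def append(self, key, value):
        if key.lower() in self.data:
            self.data[key.lower()][1].append(value)  # accumulate
        else:
            self.set(key, value)

    def get(self, key):
        return self.data[key.lower()][1]

h = MiniHeaders()
h.set("Content-Type", "text/html")
h.append("set-cookie", "a=1")
h.append("Set-Cookie", "b=2")   # same header despite different casing
assert h.get("CONTENT-TYPE") == ["text/html"]
assert h.get("set-cookie") == ["a=1", "b=2"]
```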
class H2Response(Response):
def __init__(self, handler, request):
super(H2Response, self).__init__(handler, request, response_writer_cls=H2ResponseWriter)
def write_status_headers(self):
self.writer.write_headers(self.headers, *self.status)
# Hacky way of detecting last item in generator
def write_content(self):
"""Write out the response content"""
if self.request.method != "HEAD" or self.send_body_for_head_request:
item = None
item_iter = self.iter_content()
try:
item = next(item_iter)
while True:
check_last = next(item_iter)
self.writer.write_data(item, last=False)
item = check_last
except StopIteration:
if item:
self.writer.write_data(item, last=True)
class H2ResponseWriter(object):
def __init__(self, handler, response):
self.socket = handler.request
self.h2conn = handler.conn
self._response = response
self._handler = handler
self.stream_ended = False
self.content_written = False
self.request = response.request
self.logger = response.logger
def write_headers(self, headers, status_code, status_message=None, stream_id=None, last=False):
"""
Send a HEADER frame that is tracked by the local state machine.
Write a HEADER frame using the H2 Connection object; this will only work if the stream is in a state to send
HEADER frames.
:param headers: List of (header, value) tuples
:param status_code: The HTTP status code of the response
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
:param last: Flag to signal if this is the last frame in stream.
"""
formatted_headers = []
secondary_headers = [] # Non ':' prefixed headers are to be added afterwards
for header, value in headers:
# h2_headers are native strings
# header field names are strings of ASCII
if isinstance(header, binary_type):
header = header.decode('ascii')
# value in headers can be either string or integer
if isinstance(value, binary_type):
value = self.decode(value)
if header in h2_headers:
header = ':' + header
formatted_headers.append((header, str(value)))
else:
secondary_headers.append((header, str(value)))
formatted_headers.append((':status', str(status_code)))
formatted_headers.extend(secondary_headers)
with self.h2conn as connection:
connection.send_headers(
stream_id=self.request.h2_stream_id if stream_id is None else stream_id,
headers=formatted_headers,
end_stream=last or self.request.method == "HEAD"
)
self.write(connection)
def write_data(self, item, last=False, stream_id=None):
"""
Send a DATA frame that is tracked by the local state machine.
Write a DATA frame using the H2 Connection object; this will only work if the stream is in a state to send
DATA frames. Uses flow control to split the data into multiple DATA frames if it exceeds the size that can
fit in a single frame.
:param item: The content of the DATA frame
:param last: Flag to signal if this is the last frame in stream.
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
"""
if isinstance(item, (text_type, binary_type)):
data = BytesIO(self.encode(item))
else:
data = item
# Find the length of the data
data.seek(0, 2)
data_len = data.tell()
data.seek(0)
# If the data is longer than max payload size, need to write it in chunks
payload_size = self.get_max_payload_size()
while data_len > payload_size:
self.write_data_frame(data.read(payload_size), False, stream_id)
data_len -= payload_size
payload_size = self.get_max_payload_size()
self.write_data_frame(data.read(), last, stream_id)
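The loop above splits oversized payloads across multiple DATA frames, flagging only the final one so END_STREAM is set exactly once. The chunking policy in isolation (the function name is invented, and a fixed max frame size stands in for the per-iteration flow-control query):

```python
def chunk_payload(data, max_frame_size):
    """Split a byte payload into DATA-frame-sized (chunk, is_last)
    pairs, mirroring the loop in H2ResponseWriter.write_data."""
    frames = []
    pos = 0
    while len(data) - pos > max_frame_size:
        frames.append((data[pos:pos + max_frame_size], False))
        pos += max_frame_size
    frames.append((data[pos:], True))  # remainder carries END_STREAM
    return frames

frames = chunk_payload(b"x" * 10, 4)
assert [len(chunk) for chunk, _ in frames] == [4, 4, 2]
assert [last for _, last in frames] == [False, False, True]
```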
def write_data_frame(self, data, last, stream_id=None):
with self.h2conn as connection:
connection.send_data(
stream_id=self.request.h2_stream_id if stream_id is None else stream_id,
data=data,
end_stream=last,
)
self.write(connection)
self.stream_ended = last
def write_push(self, promise_headers, push_stream_id=None, status=None, response_headers=None, response_data=None):
"""Write a push promise, and optionally write the push content.
This will write a push promise to the request stream. If you do not provide headers and data for the response,
then no response will be pushed, and you should push them yourself using the ID returned from this function.
:param promise_headers: A list of header tuples that matches what the client would use to
request the pushed response
:param push_stream_id: The ID of the stream the response should be pushed to. If none given, will
use the next available id.
:param status: The status code of the response, REQUIRED if response_headers given
:param response_headers: The headers of the response
:param response_data: The response data.
:return: The ID of the push stream
"""
with self.h2conn as connection:
push_stream_id = push_stream_id if push_stream_id is not None else connection.get_next_available_stream_id()
connection.push_stream(self.request.h2_stream_id, push_stream_id, promise_headers)
self.write(connection)
has_data = response_data is not None
if response_headers is not None:
assert status is not None
self.write_headers(response_headers, status, stream_id=push_stream_id, last=not has_data)
if has_data:
self.write_data(response_data, last=True, stream_id=push_stream_id)
return push_stream_id
def end_stream(self, stream_id=None):
"""Ends the stream with the given ID, or the one that request was made on if no ID given."""
with self.h2conn as connection:
connection.end_stream(stream_id if stream_id is not None else self.request.h2_stream_id)
self.write(connection)
self.stream_ended = True
def write_raw_header_frame(self, headers, stream_id=None, end_stream=False, end_headers=False, frame_cls=HeadersFrame):
"""
Ignores the state machine of the stream and sends a HEADER frame regardless.
Unlike `write_headers`, this does not check to see if a stream is in the correct state to have HEADER frames
sent through to it. It will build a HEADER frame and send it without using the H2 Connection object other than
to HPACK encode the headers.
:param headers: List of (header, value) tuples
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
:param end_stream: Set to True to add END_STREAM flag to frame
:param end_headers: Set to True to add END_HEADERS flag to frame
"""
if not stream_id:
stream_id = self.request.h2_stream_id
header_t = []
for header, value in headers:
header_t.append(HeaderTuple(header, value))
with self.h2conn as connection:
frame = frame_cls(stream_id, data=connection.encoder.encode(header_t))
if end_stream:
self.stream_ended = True
frame.flags.add('END_STREAM')
if end_headers:
frame.flags.add('END_HEADERS')
data = frame.serialize()
self.write_raw(data)
def write_raw_data_frame(self, data, stream_id=None, end_stream=False):
"""
Ignores the state machine of the stream and sends a DATA frame regardless.
Unlike `write_data`, this does not check to see if a stream is in the correct state to have DATA frames
sent through to it. It will build a DATA frame and send it without using the H2 Connection object. It will
not perform any flow control checks.
:param data: The data to be sent in the frame
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
:param end_stream: Set to True to add END_STREAM flag to frame
"""
if not stream_id:
stream_id = self.request.h2_stream_id
frame = DataFrame(stream_id, data=data)
if end_stream:
self.stream_ended = True
frame.flags.add('END_STREAM')
data = frame.serialize()
self.write_raw(data)
def write_raw_continuation_frame(self, headers, stream_id=None, end_headers=False):
"""
Ignores the state machine of the stream and sends a CONTINUATION frame regardless.
This provides the ability to create and write a CONTINUATION frame to the stream, which is not exposed by
`write_headers` as the h2 library handles the split between HEADER and CONTINUATION internally. Will perform
HPACK encoding on the headers.
:param headers: List of (header, value) tuples
:param stream_id: Id of stream to send frame on. Will use the request stream ID if None
:param end_headers: Set to True to add END_HEADERS flag to frame
"""
self.write_raw_header_frame(headers, stream_id=stream_id, end_headers=end_headers, frame_cls=ContinuationFrame)
def get_max_payload_size(self, stream_id=None):
"""Returns the maximum size of a payload for the given stream."""
stream_id = stream_id if stream_id is not None else self.request.h2_stream_id
with self.h2conn as connection:
return min(connection.remote_settings.max_frame_size, connection.local_flow_control_window(stream_id)) - 9
def write(self, connection):
self.content_written = True
data = connection.data_to_send()
self.socket.sendall(data)
def write_raw(self, raw_data):
"""Used for sending raw bytes/data through the socket"""
self.content_written = True
self.socket.sendall(raw_data)
def decode(self, data):
"""Convert bytes to unicode according to response.encoding."""
if isinstance(data, binary_type):
return data.decode(self._response.encoding)
elif isinstance(data, text_type):
return data
else:
raise ValueError(type(data))
def encode(self, data):
"""Convert unicode to bytes according to response.encoding."""
if isinstance(data, binary_type):
return data
elif isinstance(data, text_type):
return data.encode(self._response.encoding)
else:
raise ValueError("data %r should be text or binary, but is %s" % (data, type(data)))
class ResponseWriter(object):
"""Object providing an API to write out an HTTP response.
:param handler: The RequestHandler being used.
:param response: The Response associated with this writer."""
def __init__(self, handler, response):
self._wfile = handler.wfile
self._response = response
self._handler = handler
self._status_written = False
self._headers_seen = set()
self._headers_complete = False
self.content_written = False
self.request = response.request
self.file_chunk_size = 32 * 1024
self.default_status = 200
def _seen_header(self, name):
return self.encode(name.lower()) in self._headers_seen
def write_status(self, code, message=None):
"""Write out the status line of a response.
:param code: The integer status code of the response.
:param message: The message of the response. Defaults to the message commonly used
with the status code."""
if message is None:
if code in response_codes:
message = response_codes[code][0]
else:
message = ''
self.write(b"%s %d %s\r\n" %
(isomorphic_encode(self._response.request.protocol_version), code, isomorphic_encode(message)))
self._status_written = True
def write_header(self, name, value):
"""Write out a single header for the response.
If a status has not been written, a default status will be written (currently 200)
:param name: Name of the header field
:param value: Value of the header field
:return: A boolean indicating whether the write succeeds
"""
if not self._status_written:
self.write_status(self.default_status)
self._headers_seen.add(self.encode(name.lower()))
if not self.write(name):
return False
if not self.write(b": "):
return False
if isinstance(value, int):
if not self.write(text_type(value)):
return False
elif not self.write(value):
return False
return self.write(b"\r\n")
def write_default_headers(self):
for name, f in [("Server", self._handler.version_string),
("Date", self._handler.date_time_string)]:
if not self._seen_header(name):
if not self.write_header(name, f()):
return False
if (isinstance(self._response.content, (binary_type, text_type)) and
not self._seen_header("content-length")):
#Would be nice to avoid double-encoding here
if not self.write_header("Content-Length", len(self.encode(self._response.content))):
return False
return True
def end_headers(self):
"""Finish writing headers and write the separator.
Unless add_required_headers on the response is False,
this will also add HTTP-mandated headers that have not yet been supplied
to the response headers.
:return: A boolean indicating whether the write succeeds
"""
if self._response.add_required_headers:
if not self.write_default_headers():
return False
if not self.write("\r\n"):
return False
if not self._seen_header("content-length"):
self._response.close_connection = True
self._headers_complete = True
return True
def write_content(self, data):
"""Write the body of the response.
A default status of 200 will be written if no status has been set, and
HTTP-mandated headers will be added automatically if not explicitly set.
:return: A boolean indicating whether the write succeeds
"""
if not self._status_written:
self.write_status(self.default_status)
if not self._headers_complete:
self._response.content = data
self.end_headers()
return self.write_raw_content(data)
def write_raw_content(self, data):
"""Writes the data 'as is'"""
if data is None:
raise ValueError('data cannot be None')
if isinstance(data, (text_type, binary_type)):
# Deliberately allows both text and binary types. See `self.encode`.
return self.write(data)
else:
return self.write_content_file(data)
def write(self, data):
"""Write directly to the response, converting unicode to bytes
according to response.encoding.
:return: A boolean indicating whether the write succeeds
"""
self.content_written = True
try:
self._wfile.write(self.encode(data))
return True
except socket.error:
# This can happen if the socket got closed by the remote end
return False
def write_content_file(self, data):
"""Write a file-like object directly to the response in chunks."""
self.content_written = True
success = True
while True:
buf = data.read(self.file_chunk_size)
if not buf:
success = False
break
try:
self._wfile.write(buf)
except socket.error:
success = False
break
data.close()
return success
def encode(self, data):
"""Convert unicode to bytes according to response.encoding."""
if isinstance(data, binary_type):
return data
elif isinstance(data, text_type):
return data.encode(self._response.encoding)
else:
raise ValueError("data %r should be text or binary, but is %s" % (data, type(data)))
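To make the writer contract above concrete, here is a minimal, hedged sketch (plain bytes, not the wptserve API itself) of the wire bytes that the write_status / write_header / end_headers / write_content sequence produces for a simple 200 response:

```python
import io

# Illustrative sketch only: the bytes a ResponseWriter emits for a 200
# response. write_status produces the status line (using the default
# message for the code), write_header emits "Name: value\r\n" pairs,
# end_headers writes the blank-line separator, and write_content sends
# the body.
body = b"hello"
buf = io.BytesIO()
buf.write(b"HTTP/1.1 200 OK\r\n")                 # write_status(200)
buf.write(b"Content-Length: %d\r\n" % len(body))  # write_header(...)
buf.write(b"\r\n")                                # end_headers()
buf.write(body)                                   # write_content(body)
wire = buf.getvalue()
print(wire)
```

Note that when no Content-Length header was seen, end_headers sets close_connection so the client can detect the end of the body.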


@@ -0,0 +1,179 @@
import itertools
import re
import sys
from .logger import get_logger
from six import binary_type, text_type
any_method = object()
class RouteTokenizer(object):
def literal(self, scanner, token):
return ("literal", token)
def slash(self, scanner, token):
return ("slash", None)
def group(self, scanner, token):
return ("group", token[1:-1])
def star(self, scanner, token):
return ("star", token[1:-3])
def scan(self, input_str):
scanner = re.Scanner([(r"/", self.slash),
(r"{\w*}", self.group),
(r"\*", self.star),
(r"(?:\\.|[^{\*/])*", self.literal),])
return scanner.scan(input_str)
class RouteCompiler(object):
def __init__(self):
self.reset()
def reset(self):
self.star_seen = False
def compile(self, tokens):
self.reset()
func_map = {"slash":self.process_slash,
"literal":self.process_literal,
"group":self.process_group,
"star":self.process_star}
re_parts = ["^"]
if not tokens or tokens[0][0] != "slash":
tokens = itertools.chain([("slash", None)], tokens)
for token in tokens:
re_parts.append(func_map[token[0]](token))
if self.star_seen:
re_parts.append(")")
re_parts.append("$")
return re.compile("".join(re_parts))
def process_literal(self, token):
return re.escape(token[1])
def process_slash(self, token):
return "/"
def process_group(self, token):
if self.star_seen:
raise ValueError("Group seen after star in regexp")
return "(?P<%s>[^/]+)" % token[1]
def process_star(self, token):
if self.star_seen:
raise ValueError("Star seen after star in regexp")
self.star_seen = True
return "(.*"
def compile_path_match(route_pattern):
"""tokens: / or literal or group or *"""
tokenizer = RouteTokenizer()
tokens, unmatched = tokenizer.scan(route_pattern)
assert unmatched == "", unmatched
compiler = RouteCompiler()
return compiler.compile(tokens)
class Router(object):
"""Object for matching handler functions to requests.
:param doc_root: Absolute path of the filesystem location from
which to serve tests
:param routes: Initial routes to add; a list of three item tuples
(method, path_pattern, handler_function), defined
as for register()
"""
def __init__(self, doc_root, routes):
self.doc_root = doc_root
self.routes = []
self.logger = get_logger()
# Add the doc_root to the Python path, so that any Python handler can
# correctly locate helper scripts (see RFC_TO_BE_LINKED).
#
# TODO: In a perfect world, Router would not need to know about this
# and the handler itself would take care of it. Currently, however, we
# treat handlers like functions and so there's no easy way to do that.
if self.doc_root not in sys.path:
sys.path.insert(0, self.doc_root)
for route in reversed(routes):
self.register(*route)
def register(self, methods, path, handler):
r"""Register a handler for a set of paths.
:param methods: Set of methods this should match. "*" is a
special value indicating that all methods should
be matched.
:param path_pattern: Match pattern that will be used to determine if
a request path matches this route. Match patterns
consist of literal text, match groups, denoted
{name}, which match any character except /, and at
most one \*, which matches any character and
creates a match group extending to the end of the string.
If there is no leading "/" on the pattern, this is
automatically implied. For example::
api/{resource}/*.json
Would match `/api/test/data.json` or
`/api/test/test2/data.json`, but not `/api/test/data.py`.
The match groups are made available in the request object
as a dictionary through the route_match property. For
example, given the route pattern above and the path
`/api/test/data.json`, the route_match property would
contain::
{"resource": "test", "*": "data.json"}
:param handler: Function that will be called to process matching
requests. This must take two parameters, the request
object and the response object.
"""
if isinstance(methods, (binary_type, text_type)) or methods is any_method:
methods = [methods]
for method in methods:
self.routes.append((method, compile_path_match(path), handler))
self.logger.debug("Route pattern: %s" % self.routes[-1][1].pattern)
def get_handler(self, request):
"""Get a handler for a request or None if there is no handler.
:param request: Request to get a handler for.
:rtype: Callable or None
"""
for method, regexp, handler in reversed(self.routes):
if (request.method == method or
method in (any_method, "*") or
(request.method == "HEAD" and method == "GET")):
m = regexp.match(request.url_parts.path)
if m:
if not hasattr(handler, "__class__"):
name = handler.__name__
else:
name = handler.__class__.__name__
self.logger.debug("Found handler %s" % name)
match_parts = m.groupdict().copy()
if len(match_parts) < len(m.groups()):
match_parts["*"] = m.groups()[-1]
request.route_match = match_parts
return handler
return None
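As a hedged illustration of the grammar described in register(), the pattern `api/{resource}/*.json` compiles to a regular expression along the following lines (hand-built here to mirror RouteCompiler, not taken from the module):

```python
import re

# Hand-built equivalent of what compile_path_match produces for
# "api/{resource}/*.json": a leading "/" is implied, {name} becomes
# (?P<name>[^/]+), and the single star opens a (.* group that is
# closed at the end of the pattern.
pattern = re.compile(r"^/api/(?P<resource>[^/]+)/(.*\.json)$")

m = pattern.match("/api/test/data.json")
match_parts = m.groupdict().copy()
if len(match_parts) < len(m.groups()):  # same post-processing as get_handler
    match_parts["*"] = m.groups()[-1]

assert pattern.match("/api/test/data.py") is None
print(match_parts)  # {'resource': 'test', '*': 'data.json'}
```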


@@ -0,0 +1,6 @@
from . import handlers
from .router import any_method
routes = [(any_method, "*.py", handlers.python_script_handler),
("GET", "*.asis", handlers.as_is_handler),
("GET", "*", handlers.file_handler),
]
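Because Router.__init__ registers this list in reverse and get_handler scans the registered routes in reverse, the first entry in `routes` wins ties; "*.py" is tried before the catch-all file handler. A hedged, self-contained sketch of that priority logic (handler names are stand-in strings, and the regexes are hand-written approximations of the compiled patterns):

```python
import re

routes = [("*", r"^/(.*\.py)$", "python_script_handler"),
          ("GET", r"^/(.*\.asis)$", "as_is_handler"),
          ("GET", r"^/(.*)$", "file_handler")]

registered = []
for method, pattern, handler in reversed(routes):   # as in Router.__init__
    registered.append((method, re.compile(pattern), handler))

def get_handler(method, path):
    # Mirrors Router.get_handler: reversed scan, "*" matches any method,
    # and HEAD requests fall back to GET routes.
    for m, regexp, handler in reversed(registered):
        if method == m or m == "*" or (method == "HEAD" and m == "GET"):
            if regexp.match(path):
                return handler
    return None

print(get_handler("GET", "/test.py"))    # python_script_handler
print(get_handler("HEAD", "/test.html")) # file_handler
```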


@@ -0,0 +1,896 @@
from six.moves import BaseHTTPServer
import errno
import os
import socket
from six.moves.socketserver import ThreadingMixIn
import ssl
import sys
import threading
import time
import traceback
from six import binary_type, text_type
import uuid
from collections import OrderedDict
from six.moves.queue import Queue
from h2.config import H2Configuration
from h2.connection import H2Connection
from h2.events import RequestReceived, ConnectionTerminated, DataReceived, StreamReset, StreamEnded
from h2.exceptions import StreamClosedError, ProtocolError
from h2.settings import SettingCodes
from h2.utilities import extract_method_header
from six.moves.urllib.parse import urlsplit, urlunsplit
from mod_pywebsocket import dispatch
from mod_pywebsocket.handshake import HandshakeException
from . import routes as default_routes
from .config import ConfigBuilder
from .logger import get_logger
from .request import Server, Request, H2Request
from .response import Response, H2Response
from .router import Router
from .utils import HTTPException, isomorphic_decode, isomorphic_encode
from .constants import h2_headers
from .ws_h2_handshake import WsH2Handshaker
# We need to stress test that browsers can send/receive many headers (there is
# no specified limit), but the Python stdlib has an arbitrary limit of 100
# headers. Hitting the limit would produce an exception that is silently caught
# in Python 2 but leads to HTTP 431 in Python 3, so we monkey patch it higher.
# https://bugs.python.org/issue26586
# https://github.com/web-platform-tests/wpt/pull/24451
from six.moves import http_client
assert isinstance(getattr(http_client, '_MAXHEADERS'), int)
setattr(http_client, '_MAXHEADERS', 512)
"""
HTTP server designed for testing purposes.
The server is designed to provide flexibility in the way that
requests are handled, and to provide control both of exactly
what bytes are put on the wire for the response, and in the
timing of sending those bytes.
The server is based on the stdlib HTTPServer, but with some
notable differences in the way that requests are processed.
Overall processing is handled by a WebTestRequestHandler,
which is a subclass of BaseHTTPRequestHandler. This is responsible
for parsing the incoming request. A RequestRewriter is then
applied and may change the request data if it matches a
supplied rule.
Once the request data has been finalised, Request and Response
objects are constructed. These are used by the other parts of the
system to read information about the request and manipulate the
response.
Each request is handled by a particular handler function. The
mapping between Request and the appropriate handler is determined
by a Router. By default handlers are installed to interpret files
under the document root with .py extensions as executable python
files (see handlers.py for the api for such files), .asis files as
bytestreams to be sent literally and all other files to be served
statically.
The handler functions are responsible for either populating the
fields of the response object, which will then be written when the
handler returns, or for directly writing to the output stream.
"""
class RequestRewriter(object):
def __init__(self, rules):
"""Object for rewriting the request path.
:param rules: Initial rules to add; a list of three item tuples
(method, input_path, output_path), defined as for
register()
"""
self.rules = {}
for rule in reversed(rules):
self.register(*rule)
self.logger = get_logger()
def register(self, methods, input_path, output_path):
"""Register a rewrite rule.
:param methods: Set of methods this should match. "*" is a
special value indicating that all methods should
be matched.
:param input_path: Path to match for the initial request.
:param output_path: Path to replace the input path with in
the request.
"""
if isinstance(methods, (binary_type, text_type)):
methods = [methods]
self.rules[input_path] = (methods, output_path)
def rewrite(self, request_handler):
"""Rewrite the path in a BaseHTTPRequestHandler instance, if
it matches a rule.
:param request_handler: BaseHTTPRequestHandler for which to
rewrite the request.
"""
split_url = urlsplit(request_handler.path)
if split_url.path in self.rules:
methods, destination = self.rules[split_url.path]
if "*" in methods or request_handler.command in methods:
self.logger.debug("Rewriting request path %s to %s" %
(request_handler.path, destination))
new_url = list(split_url)
new_url[2] = destination
new_url = urlunsplit(new_url)
request_handler.path = new_url
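A hedged sketch of the rewrite step in isolation (Python 3's urllib.parse stands in for the six import, and the rule values are illustrative only):

```python
from urllib.parse import urlsplit, urlunsplit

# One rule, as register() would store it: input path -> (methods, output path).
rules = {"/old/path": (["GET", "POST"], "/new/path")}

def rewrite_path(path, method):
    # Mirrors RequestRewriter.rewrite: only the path component of the
    # URL is swapped, so query strings survive the rewrite intact.
    split_url = urlsplit(path)
    if split_url.path in rules:
        methods, destination = rules[split_url.path]
        if "*" in methods or method in methods:
            new_url = list(split_url)
            new_url[2] = destination
            return urlunsplit(new_url)
    return path

print(rewrite_path("/old/path?x=1", "GET"))  # /new/path?x=1
```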
class WebTestServer(ThreadingMixIn, BaseHTTPServer.HTTPServer):
allow_reuse_address = True
acceptable_errors = (errno.EPIPE, errno.ECONNABORTED)
request_queue_size = 2000
# Ensure that we don't hang on shutdown waiting for requests
daemon_threads = True
def __init__(self, server_address, request_handler_cls,
router, rewriter, bind_address, ws_doc_root=None,
config=None, use_ssl=False, key_file=None, certificate=None,
encrypt_after_connect=False, latency=None, http2=False, **kwargs):
"""Server for HTTP(s) Requests
:param server_address: tuple of (server_name, port)
:param request_handler_cls: BaseHTTPRequestHandler-like class to use for
handling requests.
:param router: Router instance to use for matching requests to handler
functions
:param rewriter: RequestRewriter-like instance to use for preprocessing
requests before they are routed
:param config: Dictionary holding environment configuration settings for
handlers to read, or None to use the default values.
:param use_ssl: Boolean indicating whether the server should use SSL
:param key_file: Path to key file to use if SSL is enabled.
:param certificate: Path to certificate to use if SSL is enabled.
:param ws_doc_root: Document root for websockets
:param encrypt_after_connect: For each connection, don't start encryption
until a CONNECT message has been received.
This enables the server to act as a
self-proxy.
:param bind_address: True to bind the server to both the IP address and
port specified in the server_address parameter.
False to bind the server only to the port in the
server_address parameter, but not to the address.
:param latency: Delay in ms to wait before serving each response, or
callable that returns a delay in ms
"""
self.router = router
self.rewriter = rewriter
self.scheme = "http2" if http2 else "https" if use_ssl else "http"
self.logger = get_logger()
self.latency = latency
if bind_address:
hostname_port = server_address
else:
hostname_port = ("",server_address[1])
# super() doesn't work here because BaseHTTPServer.HTTPServer is old-style
BaseHTTPServer.HTTPServer.__init__(self, hostname_port, request_handler_cls, **kwargs)
if config is not None:
Server.config = config
else:
self.logger.debug("Using default configuration")
with ConfigBuilder(browser_host=server_address[0],
ports={"http": [self.server_address[1]]}) as config:
assert config["ssl_config"] is None
Server.config = config
self.ws_doc_root = ws_doc_root
self.key_file = key_file
self.certificate = certificate
self.encrypt_after_connect = use_ssl and encrypt_after_connect
if use_ssl and not encrypt_after_connect:
if http2:
ssl_context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(keyfile=self.key_file, certfile=self.certificate)
ssl_context.set_alpn_protocols(['h2'])
self.socket = ssl_context.wrap_socket(self.socket,
server_side=True)
else:
self.socket = ssl.wrap_socket(self.socket,
keyfile=self.key_file,
certfile=self.certificate,
server_side=True)
def handle_error(self, request, client_address):
error = sys.exc_info()[1]
if ((isinstance(error, socket.error) and
isinstance(error.args, tuple) and
error.args[0] in self.acceptable_errors) or
(isinstance(error, IOError) and
error.errno in self.acceptable_errors)):
pass # remote hang up before the result is sent
else:
self.logger.error(traceback.format_exc())
class BaseWebTestRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""RequestHandler for WebTestHttpd"""
def __init__(self, *args, **kwargs):
self.logger = get_logger()
BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args, **kwargs)
def finish_handling_h1(self, request_line_is_valid):
self.server.rewriter.rewrite(self)
request = Request(self)
response = Response(self, request)
if request.method == "CONNECT":
self.handle_connect(response)
return
if not request_line_is_valid:
response.set_error(414)
response.write()
return
self.logger.debug("%s %s" % (request.method, request.request_path))
handler = self.server.router.get_handler(request)
self.finish_handling(request, response, handler)
def finish_handling(self, request, response, handler):
# If the handler we used for the request had a non-default base path
# set update the doc_root of the request to reflect this
if hasattr(handler, "base_path") and handler.base_path:
request.doc_root = handler.base_path
if hasattr(handler, "url_base") and handler.url_base != "/":
request.url_base = handler.url_base
if self.server.latency is not None:
if callable(self.server.latency):
latency = self.server.latency()
else:
latency = self.server.latency
self.logger.warning("Latency enabled. Sleeping %i ms" % latency)
time.sleep(latency / 1000.)
if handler is None:
self.logger.debug("No Handler found!")
response.set_error(404)
else:
try:
handler(request, response)
except HTTPException as e:
response.set_error(e.code, e.message)
except Exception as e:
self.respond_with_error(response, e)
self.logger.debug("%i %s %s (%s) %i" % (response.status[0],
request.method,
request.request_path,
request.headers.get('Referer'),
request.raw_input.length))
if not response.writer.content_written:
response.write()
# If a python handler has been used, the old ones won't send an END_STREAM data frame, so this
# allows for backwards compatibility by accounting for these handlers that don't close streams
if isinstance(response, H2Response) and not response.writer.stream_ended:
response.writer.end_stream()
# If we want to remove this in the future, a solution is needed for
# scripts that produce a non-string iterable of content, since these
# can't set a Content-Length header. A notable example of this kind of
# problem is with the trickle pipe i.e. foo.js?pipe=trickle(d1)
if response.close_connection:
self.close_connection = True
if not self.close_connection:
# Ensure that the whole request has been read from the socket
request.raw_input.read()
def handle_connect(self, response):
self.logger.debug("Got CONNECT")
response.status = 200
response.write()
if self.server.encrypt_after_connect:
self.logger.debug("Enabling SSL for connection")
self.request = ssl.wrap_socket(self.connection,
keyfile=self.server.key_file,
certfile=self.server.certificate,
server_side=True)
self.setup()
return
def respond_with_error(self, response, e):
message = str(e)
if message:
err = [message]
else:
err = []
err.append(traceback.format_exc())
response.set_error(500, "\n".join(err))
class Http2WebTestRequestHandler(BaseWebTestRequestHandler):
protocol_version = "HTTP/2.0"
def handle_one_request(self):
"""
This is the main HTTP/2.0 Handler.
When a browser opens a connection to the server
on the HTTP/2.0 port, the server enters this method, which initiates the h2 connection,
keeps running for the duration of the interaction, and reads/writes directly
from the socket.
Because there can be multiple H2 connections active at the same
time, a UUID is created for each so that it is easier to tell them apart in the logs.
"""
config = H2Configuration(client_side=False)
self.conn = H2ConnectionGuard(H2Connection(config=config))
self.close_connection = False
# Generate a UUID to make it easier to distinguish different H2 connection debug messages
self.uid = str(uuid.uuid4())[:8]
self.logger.debug('(%s) Initiating h2 Connection' % self.uid)
with self.conn as connection:
# Bootstrapping WebSockets with HTTP/2 specification requires
# ENABLE_CONNECT_PROTOCOL to be set in order to enable WebSocket
# over HTTP/2
new_settings = dict(connection.local_settings)
new_settings[SettingCodes.ENABLE_CONNECT_PROTOCOL] = 1
connection.local_settings.update(new_settings)
connection.local_settings.acknowledge()
connection.initiate_connection()
data = connection.data_to_send()
window_size = connection.remote_settings.initial_window_size
self.request.sendall(data)
# Dict of { stream_id: (thread, queue) }
stream_queues = {}
try:
while not self.close_connection:
data = self.request.recv(window_size)
if data == '':
self.logger.debug('(%s) Socket Closed' % self.uid)
self.close_connection = True
continue
with self.conn as connection:
frames = connection.receive_data(data)
window_size = connection.remote_settings.initial_window_size
self.logger.debug('(%s) Frames Received: ' % self.uid + str(frames))
for frame in frames:
if isinstance(frame, ConnectionTerminated):
self.logger.debug('(%s) Connection terminated by remote peer ' % self.uid)
self.close_connection = True
# Flood all the streams with connection terminated, this will cause them to stop
for stream_id, (thread, queue) in stream_queues.items():
queue.put(frame)
elif hasattr(frame, 'stream_id'):
if frame.stream_id not in stream_queues:
queue = Queue()
stream_queues[frame.stream_id] = (self.start_stream_thread(frame, queue), queue)
stream_queues[frame.stream_id][1].put(frame)
if isinstance(frame, StreamEnded) or (hasattr(frame, "stream_ended") and frame.stream_ended):
del stream_queues[frame.stream_id]
except (socket.timeout, socket.error) as e:
self.logger.error('(%s) Closing Connection - \n%s' % (self.uid, str(e)))
if not self.close_connection:
self.close_connection = True
for stream_id, (thread, queue) in stream_queues.items():
queue.put(None)
except Exception as e:
self.logger.error('(%s) Unexpected Error - \n%s' % (self.uid, str(e)))
finally:
for stream_id, (thread, queue) in stream_queues.items():
thread.join()
def _is_extended_connect_frame(self, frame):
if not isinstance(frame, RequestReceived):
return False
method = extract_method_header(frame.headers)
if method != b"CONNECT":
return False
protocol = ""
for key, value in frame.headers:
if key in (b':protocol', u':protocol'):
protocol = isomorphic_encode(value)
break
if protocol != b"websocket":
raise ProtocolError("Invalid protocol %s with CONNECT METHOD" % (protocol,))
return True
def start_stream_thread(self, frame, queue):
"""
This starts a new thread to handle frames for a specific stream.
:param frame: The first frame on the stream
:param queue: A queue object that the thread will use to check for new frames
:return: The thread object that has already been started
"""
if self._is_extended_connect_frame(frame):
target = Http2WebTestRequestHandler._stream_ws_thread
else:
target = Http2WebTestRequestHandler._stream_thread
t = threading.Thread(
target=target,
args=(self, frame.stream_id, queue)
)
t.start()
return t
def _stream_ws_thread(self, stream_id, queue):
frame = queue.get(True, None)
rfile, wfile = os.pipe()
rfile, wfile = os.fdopen(rfile, 'rb'), os.fdopen(wfile, 'wb', 0) # needs to be unbuffered for websockets
stream_handler = H2HandlerCopy(self, frame, rfile)
h2request = H2Request(stream_handler)
h2response = H2Response(stream_handler, h2request)
dispatcher = dispatch.Dispatcher(self.server.ws_doc_root, None, False)
if not dispatcher.get_handler_suite(stream_handler.path):
h2response.set_error(404)
h2response.write()
return
request_wrapper = _WebSocketRequest(stream_handler, h2response)
handshaker = WsH2Handshaker(request_wrapper, dispatcher)
try:
handshaker.do_handshake()
except HandshakeException as e:
self.logger.info('Handshake failed with error: %s', e)
h2response.set_error(e.status)
h2response.write()
return
# h2 Handshaker prepares the headers but does not send them down the
# wire. Flush the headers here.
h2response.write_status_headers()
request_wrapper._dispatcher = dispatcher
# we need two threads:
# - one to handle the frame queue
# - one to handle the request (dispatcher.transfer_data is blocking)
# the alternative is to have only one (blocking) thread. That thread
# will call transfer_data. That would require a special case in
# handle_one_request, to bypass the queue and write data to wfile
# directly.
t = threading.Thread(
target=Http2WebTestRequestHandler._stream_ws_sub_thread,
args=(self, request_wrapper, stream_handler, queue)
)
t.start()
while not self.close_connection:
frame = queue.get(True, None)
if isinstance(frame, DataReceived):
wfile.write(frame.data)
if frame.stream_ended:
raise NotImplementedError("frame.stream_ended")
wfile.close()
elif frame is None or isinstance(frame, (StreamReset, StreamEnded, ConnectionTerminated)):
self.logger.debug('(%s - %s) Stream Reset, Thread Closing' % (self.uid, stream_id))
break
t.join()
def _stream_ws_sub_thread(self, request, stream_handler, queue):
dispatcher = request._dispatcher
dispatcher.transfer_data(request)
stream_id = stream_handler.h2_stream_id
with stream_handler.conn as connection:
try:
connection.end_stream(stream_id)
data = connection.data_to_send()
stream_handler.request.sendall(data)
except StreamClosedError: # maybe the stream has already been closed
pass
queue.put(None)
def _stream_thread(self, stream_id, queue):
"""
This thread processes frames for a specific stream. It waits for frames to be placed
in the queue, and processes them. When it receives a request frame, it will start processing
immediately, even if there are data frames to follow. One of the reasons for this is that it
can detect invalid requests before needing to read the rest of the frames.
"""
# The file-like pipe object that will be used to share data to request object if data is received
wfile = None
request = None
response = None
req_handler = None
while not self.close_connection:
# Wait for next frame, blocking
frame = queue.get(True, None)
self.logger.debug('(%s - %s) %s' % (self.uid, stream_id, str(frame)))
if isinstance(frame, RequestReceived):
rfile, wfile = os.pipe()
rfile, wfile = os.fdopen(rfile, 'rb'), os.fdopen(wfile, 'wb')
stream_handler = H2HandlerCopy(self, frame, rfile)
stream_handler.server.rewriter.rewrite(stream_handler)
request = H2Request(stream_handler)
response = H2Response(stream_handler, request)
req_handler = stream_handler.server.router.get_handler(request)
if hasattr(req_handler, "frame_handler"):
# Convert this to a handler that will utilise H2 specific functionality, such as handling individual frames
req_handler = self.frame_handler(request, response, req_handler)
if hasattr(req_handler, 'handle_headers'):
req_handler.handle_headers(frame, request, response)
elif isinstance(frame, DataReceived):
wfile.write(frame.data)
if hasattr(req_handler, 'handle_data'):
req_handler.handle_data(frame, request, response)
if frame.stream_ended:
wfile.close()
elif frame is None or isinstance(frame, (StreamReset, StreamEnded, ConnectionTerminated)):
self.logger.debug('(%s - %s) Stream Reset, Thread Closing' % (self.uid, stream_id))
break
if request is not None:
request.frames.append(frame)
if hasattr(frame, "stream_ended") and frame.stream_ended:
self.finish_handling(request, response, req_handler)
def frame_handler(self, request, response, handler):
try:
return handler.frame_handler(request)
except HTTPException as e:
response.set_error(e.code, e.message)
response.write()
except Exception as e:
self.respond_with_error(response, e)
response.write()
class H2ConnectionGuard(object):
"""H2Connection objects are not threadsafe, so this keeps thread safety"""
lock = threading.Lock()
def __init__(self, obj):
assert isinstance(obj, H2Connection)
self.obj = obj
def __enter__(self):
self.lock.acquire()
return self.obj
def __exit__(self, exception_type, exception_value, traceback):
self.lock.release()
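The guard pattern above can be shown in isolation: a class-level lock is acquired on `__enter__` and the guarded object is returned, so every `with` block holds exclusive access. This is a minimal sketch with illustrative names, not part of wptserve:

```python
import threading

class ConnectionGuard(object):
    """Serialize access to a shared, non-threadsafe object."""
    lock = threading.Lock()

    def __init__(self, obj):
        self.obj = obj

    def __enter__(self):
        self.lock.acquire()
        return self.obj

    def __exit__(self, exc_type, exc_value, traceback):
        self.lock.release()

shared = []

def worker(n):
    # All appends happen under the class-level lock
    with ConnectionGuard(shared) as obj:
        obj.append(n)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```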
class H2Headers(dict):
def __init__(self, headers):
self.raw_headers = OrderedDict()
for key, val in headers:
key = isomorphic_decode(key)
val = isomorphic_decode(val)
self.raw_headers[key] = val
dict.__setitem__(self, self._convert_h2_header_to_h1(key), val)
def _convert_h2_header_to_h1(self, header_key):
if header_key[1:] in h2_headers and header_key[0] == ':':
return header_key[1:]
else:
return header_key
# TODO This does not seem relevant for H2 headers, so using a dummy function for now
def getallmatchingheaders(self, header):
return ['dummy function']
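The pseudo-header mapping performed by `H2Headers` can be sketched on its own: HTTP/2 pseudo-headers such as `:method` are re-exposed under their bare names so HTTP/1-era handler code keeps working. The `H2_PSEUDO_HEADERS` list below is illustrative; wptserve keeps its own `h2_headers` list:

```python
# Illustrative list of HTTP/2 pseudo-header names (wptserve's h2_headers)
H2_PSEUDO_HEADERS = ["method", "scheme", "authority", "path", "protocol"]

def convert_h2_header_to_h1(header_key):
    # Strip the leading ":" from known pseudo-headers only
    if header_key.startswith(":") and header_key[1:] in H2_PSEUDO_HEADERS:
        return header_key[1:]
    return header_key

headers = dict((convert_h2_header_to_h1(k), v)
               for k, v in [(":method", "GET"), (":path", "/"), ("accept", "*/*")])
```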
class H2HandlerCopy(object):
def __init__(self, handler, req_frame, rfile):
self.headers = H2Headers(req_frame.headers)
self.command = self.headers['method']
self.path = self.headers['path']
self.h2_stream_id = req_frame.stream_id
self.server = handler.server
self.protocol_version = handler.protocol_version
self.client_address = handler.client_address
self.raw_requestline = ''
self.rfile = rfile
self.request = handler.request
self.conn = handler.conn
class Http1WebTestRequestHandler(BaseWebTestRequestHandler):
protocol_version = "HTTP/1.1"
def handle_one_request(self):
response = None
try:
self.close_connection = False
request_line_is_valid = self.get_request_line()
if self.close_connection:
return
request_is_valid = self.parse_request()
if not request_is_valid:
# parse_request() actually sends its own error responses
return
self.finish_handling_h1(request_line_is_valid)
except socket.timeout as e:
self.log_error("Request timed out: %r", e)
self.close_connection = True
return
except Exception:
err = traceback.format_exc()
if response:
response.set_error(500, err)
response.write()
self.logger.error(err)
def get_request_line(self):
try:
self.raw_requestline = self.rfile.readline(65537)
except socket.error:
self.close_connection = True
return False
if len(self.raw_requestline) > 65536:
self.requestline = ''
self.request_version = ''
self.command = ''
return False
if not self.raw_requestline:
self.close_connection = True
return True
class WebTestHttpd(object):
"""
:param host: Host from which to serve (default: 127.0.0.1)
:param port: Port from which to serve (default: 8000)
:param server_cls: Class to use for the server (default depends on ssl vs non-ssl)
:param handler_cls: Class to use for the RequestHandler
:param use_ssl: Use a SSL server if no explicit server_cls is supplied
:param key_file: Path to key file to use if ssl is enabled
:param certificate: Path to certificate file to use if ssl is enabled
:param encrypt_after_connect: For each connection, don't start encryption
until a CONNECT message has been received.
This enables the server to act as a
self-proxy.
:param router_cls: Router class to use when matching URLs to handlers
:param doc_root: Document root for serving files
:param ws_doc_root: Document root for websockets
:param routes: List of routes with which to initialize the router
:param rewriter_cls: Class to use for request rewriter
:param rewrites: List of rewrites with which to initialize the rewriter_cls
:param config: Dictionary holding environment configuration settings for
handlers to read, or None to use the default values.
:param bind_address: Boolean indicating whether to bind server to IP address.
:param latency: Delay in ms to wait before serving each response, or
callable that returns a delay in ms
HTTP server designed for testing scenarios.
Takes a router class which provides one method get_handler which takes a Request
and returns a handler function.
.. attribute:: host
The host name or ip address of the server
.. attribute:: port
The port on which the server is running
.. attribute:: router
The Router object used to associate requests with resources for this server
.. attribute:: rewriter
The Rewriter object used for URL rewriting
.. attribute:: use_ssl
Boolean indicating whether the server is using ssl
.. attribute:: started
Boolean indicating whether the server is running
"""
def __init__(self, host="127.0.0.1", port=8000,
server_cls=None, handler_cls=Http1WebTestRequestHandler,
use_ssl=False, key_file=None, certificate=None, encrypt_after_connect=False,
router_cls=Router, doc_root=os.curdir, ws_doc_root=None, routes=None,
rewriter_cls=RequestRewriter, bind_address=True, rewrites=None,
latency=None, config=None, http2=False):
if routes is None:
routes = default_routes.routes
self.host = host
self.router = router_cls(doc_root, routes)
self.rewriter = rewriter_cls(rewrites if rewrites is not None else [])
self.use_ssl = use_ssl
self.http2 = http2
self.logger = get_logger()
if server_cls is None:
server_cls = WebTestServer
if use_ssl:
if not os.path.exists(key_file):
raise ValueError("SSL key file not found: {}".format(key_file))
if not os.path.exists(certificate):
raise ValueError("SSL certificate not found: {}".format(certificate))
try:
self.httpd = server_cls((host, port),
handler_cls,
self.router,
self.rewriter,
config=config,
bind_address=bind_address,
ws_doc_root=ws_doc_root,
use_ssl=use_ssl,
key_file=key_file,
certificate=certificate,
encrypt_after_connect=encrypt_after_connect,
latency=latency,
http2=http2)
self.started = False
_host, self.port = self.httpd.socket.getsockname()
except Exception:
self.logger.critical("Failed to start HTTP server on port %s; "
"is something already using that port?" % port)
raise
def start(self, block=False):
"""Start the server.
:param block: True to run the server on the current thread, blocking,
False to run on a separate thread."""
http_type = "http2" if self.http2 else "https" if self.use_ssl else "http"
self.logger.info("Starting %s server on %s:%s" % (http_type, self.host, self.port))
self.started = True
if block:
self.httpd.serve_forever()
else:
self.server_thread = threading.Thread(target=self.httpd.serve_forever)
self.server_thread.setDaemon(True) # don't hang on exit
self.server_thread.start()
def stop(self):
"""
Stops the server.
If the server is not running, this method has no effect.
"""
if self.started:
try:
self.httpd.shutdown()
self.httpd.server_close()
self.server_thread.join()
self.server_thread = None
self.logger.info("Stopped http server on %s:%s" % (self.host, self.port))
except AttributeError:
pass
self.started = False
self.httpd = None
def get_url(self, path="/", query=None, fragment=None):
if not self.started:
return None
return urlunsplit(("http" if not self.use_ssl else "https",
"%s:%s" % (self.host, self.port),
path, query, fragment))
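`get_url` is a thin wrapper over `urlunsplit`. A standalone sketch of the same URL construction (using the Python 3 `urllib.parse` spelling; the vendored code goes through `six`):

```python
from urllib.parse import urlunsplit  # py3 spelling; wptserve uses six

def get_url(host, port, use_ssl, path="/", query=None, fragment=None):
    """Build a URL for a server listening on host:port."""
    scheme = "https" if use_ssl else "http"
    return urlunsplit((scheme, "%s:%s" % (host, port), path, query, fragment))

url = get_url("127.0.0.1", 8000, False, path="/echo", query="a=1")
```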
class _WebSocketConnection(object):
def __init__(self, request_handler, response):
"""Mimic mod_python mp_conn.
:param request_handler: A H2HandlerCopy instance.
:param response: A H2Response instance.
"""
self._request_handler = request_handler
self._response = response
self.remote_addr = self._request_handler.client_address
def write(self, data):
self._response.writer.write_data(data, False)
def read(self, length):
return self._request_handler.rfile.read(length)
class _WebSocketRequest(object):
def __init__(self, request_handler, response):
"""Mimic mod_python request.
:param request_handler: A H2HandlerCopy instance.
:param response: A H2Response instance.
"""
self.connection = _WebSocketConnection(request_handler, response)
self._response = response
self.uri = request_handler.path
self.unparsed_uri = request_handler.path
self.method = request_handler.command
# read headers from request_handler
self.headers_in = request_handler.headers
# write headers directly into H2Response
self.headers_out = response.headers
# proxies status to H2Response
@property
def status(self):
return self._response.status
@status.setter
def status(self, status):
self._response.status = status


@@ -0,0 +1,14 @@
from .base import NoSSLEnvironment
from .openssl import OpenSSLEnvironment
from .pregenerated import PregeneratedSSLEnvironment
environments = {"none": NoSSLEnvironment,
"openssl": OpenSSLEnvironment,
"pregenerated": PregeneratedSSLEnvironment}
def get_cls(name):
try:
return environments[name]
except KeyError:
raise ValueError("%s is not a valid SSL type." % name)


@@ -0,0 +1,17 @@
class NoSSLEnvironment(object):
ssl_enabled = False
def __init__(self, *args, **kwargs):
pass
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
pass
def host_cert_path(self, hosts):
return None, None
def ca_cert_path(self, hosts):
return None


@@ -0,0 +1,439 @@
import functools
import os
import random
import shutil
import subprocess
import tempfile
from datetime import datetime, timedelta
from six import iteritems, PY2
# Amount of time beyond the present to consider certificates "expired." This
# allows certificates to be proactively re-generated in the "buffer" period
# prior to their exact expiration time.
CERT_EXPIRY_BUFFER = dict(hours=6)
def _ensure_str(s, encoding):
"""makes sure s is an instance of str, converting with encoding if needed"""
if isinstance(s, str):
return s
if PY2:
return s.encode(encoding)
else:
return s.decode(encoding)
class OpenSSL(object):
def __init__(self, logger, binary, base_path, conf_path, hosts, duration,
base_conf_path=None):
"""Context manager for interacting with OpenSSL.
Creates a config file for the duration of the context.
:param logger: stdlib logger or python structured logger
:param binary: path to openssl binary
:param base_path: path to directory for storing certificates
:param conf_path: path for configuration file storing configuration data
:param hosts: list of hosts to include in configuration (or None if not
generating host certificates)
:param duration: Certificate duration in days"""
self.base_path = base_path
self.binary = binary
self.conf_path = conf_path
self.base_conf_path = base_conf_path
self.logger = logger
self.proc = None
self.cmd = []
self.hosts = hosts
self.duration = duration
def __enter__(self):
with open(self.conf_path, "w") as f:
f.write(get_config(self.base_path, self.hosts, self.duration))
return self
def __exit__(self, *args, **kwargs):
os.unlink(self.conf_path)
def log(self, line):
if hasattr(self.logger, "process_output"):
self.logger.process_output(self.proc.pid if self.proc is not None else None,
line.decode("utf8", "replace"),
command=" ".join(self.cmd))
else:
self.logger.debug(line)
def __call__(self, cmd, *args, **kwargs):
"""Run a command using OpenSSL in the current context.
:param cmd: The openssl subcommand to run
:param *args: Additional arguments to pass to the command
"""
self.cmd = [self.binary, cmd]
if cmd != "x509":
self.cmd += ["-config", self.conf_path]
self.cmd += list(args)
# Copy the environment, converting to plain strings. Win32 StartProcess
# is picky about all the keys/values being str (on both Py2/3).
env = {}
for k, v in iteritems(os.environ):
env[_ensure_str(k, "utf8")] = _ensure_str(v, "utf8")
if self.base_conf_path is not None:
env["OPENSSL_CONF"] = _ensure_str(self.base_conf_path, "utf-8")
self.proc = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
env=env)
stdout, stderr = self.proc.communicate()
self.log(stdout)
if self.proc.returncode != 0:
raise subprocess.CalledProcessError(self.proc.returncode, self.cmd,
output=stdout)
self.cmd = []
self.proc = None
return stdout
def make_subject(common_name,
country=None,
state=None,
locality=None,
organization=None,
organization_unit=None):
args = [("country", "C"),
("state", "ST"),
("locality", "L"),
("organization", "O"),
("organization_unit", "OU"),
("common_name", "CN")]
rv = []
for var, key in args:
value = locals()[var]
if value is not None:
rv.append("/%s=%s" % (key, value.replace("/", "\\/")))
return "".join(rv)
def make_alt_names(hosts):
return ",".join("DNS:%s" % host for host in hosts)
def make_name_constraints(hosts):
return ",".join("permitted;DNS:%s" % host for host in hosts)
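The helpers above assemble OpenSSL distinguished-name and SAN strings. A simplified, self-contained sketch of the subject builder (values below are examples, not wptserve defaults):

```python
def make_subject(common_name, organization=None):
    # Simplified make_subject: emit /KEY=value pairs, escaping "/" in
    # values so a value cannot terminate its field early.
    rv = []
    for key, value in [("O", organization), ("CN", common_name)]:
        if value is not None:
            rv.append("/%s=%s" % (key, value.replace("/", "\\/")))
    return "".join(rv)

subject = make_subject("web-platform-tests", organization="WPT")
alt_names = ",".join("DNS:%s" % h
                     for h in ["web-platform.test", "www.web-platform.test"])
```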
def get_config(root_dir, hosts, duration=30):
if hosts is None:
san_line = ""
constraints_line = ""
else:
san_line = "subjectAltName = %s" % make_alt_names(hosts)
constraints_line = "nameConstraints = " + make_name_constraints(hosts)
if os.path.sep == "\\":
# This seems to be needed for the Shining Light OpenSSL on
# Windows, at least.
root_dir = root_dir.replace("\\", "\\\\")
rv = """[ ca ]
default_ca = CA_default
[ CA_default ]
dir = %(root_dir)s
certs = $dir
new_certs_dir = $certs
crl_dir = $dir%(sep)scrl
database = $dir%(sep)sindex.txt
private_key = $dir%(sep)scacert.key
certificate = $dir%(sep)scacert.pem
serial = $dir%(sep)sserial
crldir = $dir%(sep)scrl
crlnumber = $dir%(sep)scrlnumber
crl = $crldir%(sep)scrl.pem
RANDFILE = $dir%(sep)sprivate%(sep)s.rand
x509_extensions = usr_cert
name_opt = ca_default
cert_opt = ca_default
default_days = %(duration)d
default_crl_days = %(duration)d
default_md = sha256
preserve = no
policy = policy_anything
copy_extensions = copy
[ policy_anything ]
countryName = optional
stateOrProvinceName = optional
localityName = optional
organizationName = optional
organizationalUnitName = optional
commonName = supplied
emailAddress = optional
[ req ]
default_bits = 2048
default_keyfile = privkey.pem
distinguished_name = req_distinguished_name
attributes = req_attributes
x509_extensions = v3_ca
# Passwords for private keys if not present they will be prompted for
# input_password = secret
# output_password = secret
string_mask = utf8only
req_extensions = v3_req
[ req_distinguished_name ]
countryName = Country Name (2 letter code)
countryName_default = AU
countryName_min = 2
countryName_max = 2
stateOrProvinceName = State or Province Name (full name)
stateOrProvinceName_default =
localityName = Locality Name (eg, city)
0.organizationName = Organization Name
0.organizationName_default = Web Platform Tests
organizationalUnitName = Organizational Unit Name (eg, section)
#organizationalUnitName_default =
commonName = Common Name (e.g. server FQDN or YOUR name)
commonName_max = 64
emailAddress = Email Address
emailAddress_max = 64
[ req_attributes ]
[ usr_cert ]
basicConstraints=CA:false
subjectKeyIdentifier=hash
authorityKeyIdentifier=keyid,issuer
[ v3_req ]
basicConstraints = CA:FALSE
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
extendedKeyUsage = serverAuth
%(san_line)s
[ v3_ca ]
basicConstraints = CA:true
subjectKeyIdentifier=hash
authorityKeyIdentifier=keyid:always,issuer:always
keyUsage = keyCertSign
%(constraints_line)s
""" % {"root_dir": root_dir,
"san_line": san_line,
"duration": duration,
"constraints_line": constraints_line,
"sep": os.path.sep.replace("\\", "\\\\")}
return rv
class OpenSSLEnvironment(object):
ssl_enabled = True
def __init__(self, logger, openssl_binary="openssl", base_path=None,
password="web-platform-tests", force_regenerate=False,
duration=30, base_conf_path=None):
"""SSL environment that creates a local CA and host certificate using OpenSSL.
By default this will look in base_path for existing certificates that are still
valid and only create new certificates if there aren't any. This behaviour can
be adjusted using the force_regenerate option.
:param logger: a stdlib logging compatible logger or mozlog structured logger
:param openssl_binary: Path to the OpenSSL binary
:param base_path: Path in which certificates will be stored. If None, a temporary
directory will be used and removed when the server shuts down
:param password: Password to use
:param force_regenerate: Always create a new certificate even if one already exists.
"""
self.logger = logger
self.temporary = False
if base_path is None:
base_path = tempfile.mkdtemp()
self.temporary = True
self.base_path = os.path.abspath(base_path)
self.password = password
self.force_regenerate = force_regenerate
self.duration = duration
self.base_conf_path = base_conf_path
self.path = None
self.binary = openssl_binary
self.openssl = None
self._ca_cert_path = None
self._ca_key_path = None
self.host_certificates = {}
def __enter__(self):
if not os.path.exists(self.base_path):
os.makedirs(self.base_path)
path = functools.partial(os.path.join, self.base_path)
with open(path("index.txt"), "w"):
pass
with open(path("serial"), "w") as f:
serial = "%x" % random.randint(0, 1000000)
if len(serial) % 2:
serial = "0" + serial
f.write(serial)
self.path = path
return self
def __exit__(self, *args, **kwargs):
if self.temporary:
shutil.rmtree(self.base_path)
def _config_openssl(self, hosts):
conf_path = self.path("openssl.cfg")
return OpenSSL(self.logger, self.binary, self.base_path, conf_path, hosts,
self.duration, self.base_conf_path)
def ca_cert_path(self, hosts):
"""Get the path to the CA certificate file, generating a
new one if needed"""
if self._ca_cert_path is None and not self.force_regenerate:
self._load_ca_cert()
if self._ca_cert_path is None:
self._generate_ca(hosts)
return self._ca_cert_path
def _load_ca_cert(self):
key_path = self.path("cacert.key")
cert_path = self.path("cacert.pem")
if self.check_key_cert(key_path, cert_path, None):
self.logger.info("Using existing CA cert")
self._ca_key_path, self._ca_cert_path = key_path, cert_path
def check_key_cert(self, key_path, cert_path, hosts):
"""Check that a key and cert file exist and are valid"""
if not os.path.exists(key_path) or not os.path.exists(cert_path):
return False
with self._config_openssl(hosts) as openssl:
end_date_str = openssl("x509",
"-noout",
"-enddate",
"-in", cert_path).split("=", 1)[1].strip()
# Not sure if this works in other locales
end_date = datetime.strptime(end_date_str, "%b %d %H:%M:%S %Y %Z")
time_buffer = timedelta(**CERT_EXPIRY_BUFFER)
# Because `strptime` does not account for time zone offsets, it is
# always in terms of UTC, so the current time should be calculated
# accordingly.
if end_date < datetime.utcnow() + time_buffer:
return False
#TODO: check the key actually signed the cert.
return True
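The expiry check above can be exercised standalone: `openssl x509 -noout -enddate` prints a line like `notAfter=Jun  1 12:00:00 2031 GMT`, which is parsed with `strptime` (the timezone name is matched but not applied, hence comparing against UTC) and padded with a buffer so nearly-expired certificates are regenerated proactively. A sketch under those assumptions:

```python
from datetime import datetime, timedelta

# Buffer beyond "now" within which certs count as expired (as above)
CERT_EXPIRY_BUFFER = dict(hours=6)

def cert_still_valid(enddate_output, now=None):
    # enddate_output is the raw "notAfter=..." line from openssl x509
    end_date_str = enddate_output.split("=", 1)[1].strip()
    end_date = datetime.strptime(end_date_str, "%b %d %H:%M:%S %Y %Z")
    now = now or datetime.utcnow()
    return end_date >= now + timedelta(**CERT_EXPIRY_BUFFER)

ok = cert_still_valid("notAfter=Jun  1 12:00:00 2031 GMT",
                      now=datetime(2021, 3, 3))
expired = cert_still_valid("notAfter=Jun  1 12:00:00 2020 GMT",
                           now=datetime(2021, 3, 3))
```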
def _generate_ca(self, hosts):
path = self.path
self.logger.info("Generating new CA in %s" % self.base_path)
key_path = path("cacert.key")
req_path = path("careq.pem")
cert_path = path("cacert.pem")
with self._config_openssl(hosts) as openssl:
openssl("req",
"-batch",
"-new",
"-newkey", "rsa:2048",
"-keyout", key_path,
"-out", req_path,
"-subj", make_subject("web-platform-tests"),
"-passout", "pass:%s" % self.password)
openssl("ca",
"-batch",
"-create_serial",
"-keyfile", key_path,
"-passin", "pass:%s" % self.password,
"-selfsign",
"-extensions", "v3_ca",
"-notext",
"-in", req_path,
"-out", cert_path)
os.unlink(req_path)
self._ca_key_path, self._ca_cert_path = key_path, cert_path
def host_cert_path(self, hosts):
"""Get a tuple of (private key path, certificate path) for a host,
generating new ones if necessary.
hosts must be a list of all hosts to appear on the certificate, with
the primary hostname first."""
hosts = tuple(sorted(hosts, key=len))
if hosts not in self.host_certificates:
if not self.force_regenerate:
key_cert = self._load_host_cert(hosts)
else:
key_cert = None
if key_cert is None:
key, cert = self._generate_host_cert(hosts)
else:
key, cert = key_cert
self.host_certificates[hosts] = key, cert
return self.host_certificates[hosts]
def _load_host_cert(self, hosts):
host = hosts[0]
key_path = self.path("%s.key" % host)
cert_path = self.path("%s.pem" % host)
# TODO: check that this cert was signed by the CA cert
if self.check_key_cert(key_path, cert_path, hosts):
self.logger.info("Using existing host cert")
return key_path, cert_path
def _generate_host_cert(self, hosts):
host = hosts[0]
if not self.force_regenerate:
self._load_ca_cert()
if self._ca_key_path is None:
self._generate_ca(hosts)
ca_key_path = self._ca_key_path
assert os.path.exists(ca_key_path)
path = self.path
req_path = path("wpt.req")
cert_path = path("%s.pem" % host)
key_path = path("%s.key" % host)
self.logger.info("Generating new host cert")
with self._config_openssl(hosts) as openssl:
openssl("req",
"-batch",
"-newkey", "rsa:2048",
"-keyout", key_path,
"-in", ca_key_path,
"-nodes",
"-out", req_path)
openssl("ca",
"-batch",
"-in", req_path,
"-passin", "pass:%s" % self.password,
"-subj", make_subject(host),
"-out", cert_path)
os.unlink(req_path)
return key_path, cert_path


@@ -0,0 +1,26 @@
class PregeneratedSSLEnvironment(object):
"""SSL environment to use with existing key/certificate files
e.g. when running on a server with a public domain name
"""
ssl_enabled = True
def __init__(self, logger, host_key_path, host_cert_path,
ca_cert_path=None):
self._ca_cert_path = ca_cert_path
self._host_key_path = host_key_path
self._host_cert_path = host_cert_path
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
pass
def host_cert_path(self, hosts):
"""Return the key and certificate paths for the host"""
return self._host_key_path, self._host_cert_path
def ca_cert_path(self, hosts):
"""Return the certificate path of the CA that signed the
host certificates, or None if that isn't known"""
return self._ca_cert_path


@@ -0,0 +1,204 @@
import base64
import json
import os
import six
import threading
import uuid
from multiprocessing.managers import AcquirerProxy, BaseManager, DictProxy
from six import text_type, binary_type
from .utils import isomorphic_encode
class ServerDictManager(BaseManager):
shared_data = {}
def _get_shared():
return ServerDictManager.shared_data
ServerDictManager.register("get_dict",
callable=_get_shared,
proxytype=DictProxy)
ServerDictManager.register('Lock', threading.Lock, AcquirerProxy)
class ClientDictManager(BaseManager):
pass
ClientDictManager.register("get_dict")
ClientDictManager.register("Lock")
class StashServer(object):
def __init__(self, address=None, authkey=None, mp_context=None):
self.address = address
self.authkey = authkey
self.manager = None
self.mp_context = mp_context
def __enter__(self):
self.manager, self.address, self.authkey = start_server(self.address,
self.authkey,
self.mp_context)
store_env_config(self.address, self.authkey)
def __exit__(self, *args, **kwargs):
if self.manager is not None:
self.manager.shutdown()
def load_env_config():
address, authkey = json.loads(os.environ["WPT_STASH_CONFIG"])
if isinstance(address, list):
address = tuple(address)
else:
address = str(address)
authkey = base64.b64decode(authkey)
return address, authkey
def store_env_config(address, authkey):
authkey = base64.b64encode(authkey)
os.environ["WPT_STASH_CONFIG"] = json.dumps((address, authkey.decode("ascii")))
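The two functions above round-trip the stash server's address and authkey through the `WPT_STASH_CONFIG` environment variable so subprocesses can connect. The authkey is raw bytes, so it is base64-encoded to survive JSON; TCP addresses are `(host, port)` tuples, which JSON turns into lists, hence the `tuple()` on the way back. A sketch that takes the environment dict as a parameter for testability:

```python
import base64
import json

def store_env_config(environ, address, authkey):
    # Encode the binary authkey so the pair can be serialized as JSON
    authkey = base64.b64encode(authkey)
    environ["WPT_STASH_CONFIG"] = json.dumps((address, authkey.decode("ascii")))

def load_env_config(environ):
    address, authkey = json.loads(environ["WPT_STASH_CONFIG"])
    if isinstance(address, list):
        # JSON has no tuples; restore the (host, port) form
        address = tuple(address)
    return address, base64.b64decode(authkey)

env = {}
store_env_config(env, ("localhost", 8888), b"secret-key")
address, authkey = load_env_config(env)
```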
def start_server(address=None, authkey=None, mp_context=None):
if isinstance(authkey, text_type):
authkey = authkey.encode("ascii")
kwargs = {}
if six.PY3 and mp_context is not None:
kwargs["ctx"] = mp_context
manager = ServerDictManager(address, authkey, **kwargs)
manager.start()
address = manager._address
if isinstance(address, bytes):
address = address.decode("ascii")
return (manager, address, manager._authkey)
class LockWrapper(object):
def __init__(self, lock):
self.lock = lock
def acquire(self):
self.lock.acquire()
def release(self):
self.lock.release()
def __enter__(self):
self.acquire()
def __exit__(self, *args, **kwargs):
self.release()
#TODO: Consider expiring values after some fixed time for long-running
#servers
class Stash(object):
"""Key-value store for persisting data across HTTP/S and WS/S requests.
This data store is specifically designed for persisting data across server
requests. The synchronization is achieved by using the BaseManager from
the multiprocessing module so different processes can access the same data.
Stash can be used interchangeably between HTTP, HTTPS, WS and WSS servers.
A thing to note about WS/S servers is that they require additional steps in
the handlers for accessing the same underlying shared data in the Stash.
This can usually be achieved by using load_env_config(). When using Stash
interchangeably between HTTP/S and WS/S requests, the path part of the key
should be explicitly specified if accessing the same key/value subset.
The store has several unusual properties. Keys are of the form (path,
uuid), where path is, by default, the path in the HTTP request and
uuid is a unique id. In addition, the store is write-once, read-once,
i.e. the value associated with a particular key cannot be changed once
written and the read operation (called "take") is destructive. Taken together,
these properties make it difficult for data to accidentally leak
between different resources or different requests for the same
resource.
"""
_proxy = None
lock = None
_initializing = threading.Lock()
def __init__(self, default_path, address=None, authkey=None):
self.default_path = default_path
self._get_proxy(address, authkey)
self.data = Stash._proxy
def _get_proxy(self, address=None, authkey=None):
if address is None and authkey is None:
Stash._proxy = {}
Stash.lock = threading.Lock()
# Initializing the proxy involves connecting to the remote process and
# retrieving two proxied objects. This process is not inherently
# atomic, so a lock must be used to make it so. Atomicity ensures that
# only one thread attempts to initialize the connection and that any
# threads running in parallel correctly wait for initialization to be
# fully complete.
with Stash._initializing:
if Stash.lock:
return
manager = ClientDictManager(address, authkey)
manager.connect()
Stash._proxy = manager.get_dict()
Stash.lock = LockWrapper(manager.Lock())
def _wrap_key(self, key, path):
if path is None:
path = self.default_path
# This key format is required to support using the path, since the
# data passed into the stash can be a DictProxy, which wouldn't
# detect changes when writing to a subdict.
if isinstance(key, binary_type):
# UUIDs are within the ASCII charset.
key = key.decode('ascii')
return (isomorphic_encode(path), uuid.UUID(key).bytes)
def put(self, key, value, path=None):
"""Place a value in the shared stash.
:param key: A UUID to use as the data's key.
:param value: The data to store. This can be any python object.
:param path: The path that has access to read the data (by default
the current request path)"""
if value is None:
raise ValueError("SharedStash value may not be set to None")
internal_key = self._wrap_key(key, path)
if internal_key in self.data:
raise StashError("Tried to overwrite existing shared stash value "
"for key %s (old value was %s, new value is %s)" %
(internal_key, self.data[internal_key], value))
else:
self.data[internal_key] = value
def take(self, key, path=None):
"""Remove a value from the shared stash and return it.
:param key: A UUID to use as the data's key.
:param path: The path that has access to read the data (by default
the current request path)"""
internal_key = self._wrap_key(key, path)
value = self.data.get(internal_key, None)
if value is not None:
try:
self.data.pop(internal_key)
except KeyError:
# Silently continue when pop error occurs.
pass
return value
class StashError(Exception):
pass
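The write-once, read-once contract described in the `Stash` docstring can be shown with an in-process sketch: keys are `(path, uuid)` pairs, `put` refuses to overwrite, and `take` is destructive. The real Stash proxies its dict through `multiprocessing.BaseManager`; a plain dict is enough to demonstrate the semantics:

```python
import uuid

class MiniStash(object):
    """In-process sketch of wptserve's Stash contract (not the real class)."""

    def __init__(self, default_path):
        self.default_path = default_path
        self.data = {}

    def _wrap_key(self, key, path):
        # Keys are scoped by path; the uuid is normalized to its bytes form
        return (path or self.default_path, uuid.UUID(key).bytes)

    def put(self, key, value, path=None):
        internal = self._wrap_key(key, path)
        if internal in self.data:
            raise ValueError("stash values are write-once")
        self.data[internal] = value

    def take(self, key, path=None):
        # Destructive read: the value is removed once taken
        return self.data.pop(self._wrap_key(key, path), None)

stash = MiniStash("/echo")
token = str(uuid.uuid4())
stash.put(token, "state")
first = stash.take(token)   # returns the stored value
second = stash.take(token)  # a second take finds nothing
```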


@@ -0,0 +1,165 @@
import socket
import sys
from six import binary_type, text_type
def isomorphic_decode(s):
"""Decodes a binary string into a text string using iso-8859-1.
Returns `unicode` in Python 2 and `str` in Python 3. The function is a
no-op if the argument already has a text type. iso-8859-1 is chosen because
it is an 8-bit encoding whose code points range from 0x0 to 0xFF and the
values are the same as the binary representations, so any binary string can
be decoded into and encoded from iso-8859-1 without any errors or data
loss. Python 3 also uses iso-8859-1 (or latin-1) extensively in http:
https://github.com/python/cpython/blob/273fc220b25933e443c82af6888eb1871d032fb8/Lib/http/client.py#L213
"""
if isinstance(s, text_type):
return s
if isinstance(s, binary_type):
return s.decode("iso-8859-1")
raise TypeError("Unexpected value (expecting string-like): %r" % s)
def isomorphic_encode(s):
"""Encodes a text-type string into binary data using iso-8859-1.
Returns `str` in Python 2 and `bytes` in Python 3. The function is a no-op
if the argument already has a binary type. This is the counterpart of
isomorphic_decode.
"""
if isinstance(s, binary_type):
return s
if isinstance(s, text_type):
return s.encode("iso-8859-1")
raise TypeError("Unexpected value (expecting string-like): %r" % s)
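The property these helpers rely on is that iso-8859-1 maps every byte value 0x00-0xFF to the code point with the same number, so arbitrary binary data survives a decode/encode round trip unchanged. A Python 3 sketch of the same pair of functions:

```python
def isomorphic_decode(s):
    # No-op for text; latin-1 decode for bytes (never raises)
    return s.decode("iso-8859-1") if isinstance(s, bytes) else s

def isomorphic_encode(s):
    # Counterpart: latin-1 encode for text, no-op for bytes
    return s.encode("iso-8859-1") if isinstance(s, str) else s

raw = bytes(range(256))  # every possible byte value
round_tripped = isomorphic_encode(isomorphic_decode(raw))
```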
def invert_dict(dict):
rv = {}
for key, values in dict.items():
for value in values:
if value in rv:
raise ValueError
rv[value] = key
return rv
class HTTPException(Exception):
def __init__(self, code, message=""):
self.code = code
self.message = message
def _open_socket(host, port):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if port != 0:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.listen(5)
return sock
def is_bad_port(port):
"""
Bad port as per https://fetch.spec.whatwg.org/#port-blocking
"""
return port in [
1, # tcpmux
7, # echo
9, # discard
11, # systat
13, # daytime
15, # netstat
17, # qotd
19, # chargen
20, # ftp-data
21, # ftp
22, # ssh
23, # telnet
25, # smtp
37, # time
42, # name
43, # nicname
53, # domain
77, # priv-rjs
79, # finger
87, # ttylink
95, # supdup
101, # hostname
102, # iso-tsap
103, # gppitnp
104, # acr-nema
109, # pop2
110, # pop3
111, # sunrpc
113, # auth
115, # sftp
117, # uucp-path
119, # nntp
123, # ntp
135, # loc-srv / epmap
139, # netbios
143, # imap2
179, # bgp
389, # ldap
427, # afp (alternate)
465, # smtp (alternate)
512, # print / exec
513, # login
514, # shell
515, # printer
526, # tempo
530, # courier
531, # chat
532, # netnews
540, # uucp
548, # afp
554, # rtsp
556, # remotefs
563, # nntp+ssl
587, # smtp (outgoing)
601, # syslog-conn
636, # ldap+ssl
993, # imap+ssl
995, # pop3+ssl
1720, # h323hostcall
1723, # pptp
2049, # nfs
3659, # apple-sasl
4045, # lockd
5060, # sip
5061, # sips
6000, # x11
6665, # irc (alternate)
6666, # irc (alternate)
6667, # irc (default)
6668, # irc (alternate)
6669, # irc (alternate)
6697, # irc+tls
]
def get_port(host=''):
host = host or '127.0.0.1'
port = 0
while True:
free_socket = _open_socket(host, 0)
port = free_socket.getsockname()[1]
free_socket.close()
if not is_bad_port(port):
break
return port
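The port-allocation trick in `get_port` works like this: binding to port 0 asks the OS to pick a free ephemeral port, which is read back with `getsockname` before the probe socket is closed. A minimal stdlib sketch (the real `get_port` additionally loops to skip the fetch-spec "bad ports"):

```python
import socket

def get_free_port(host="127.0.0.1"):
    # Bind to port 0 so the OS chooses an unused ephemeral port
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind((host, 0))
    port = sock.getsockname()[1]
    sock.close()
    return port

port = get_free_port()
```

Note the small race inherent to this approach: another process could grab the port between `close()` and the server's own `bind()`, which is why servers that can bind to port 0 directly (as `WebTestServer` does) report the bound port afterwards instead.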
def http2_compatible():
# Currently, the HTTP/2.0 server is only working in python 2.7.10+ or 3.6+ and OpenSSL 1.0.2+
import ssl
ssl_v = ssl.OPENSSL_VERSION_INFO
py_v = sys.version_info
return (((py_v[0] == 2 and py_v[1] == 7 and py_v[2] >= 10) or (py_v[0] == 3 and py_v[1] >= 6)) and
(ssl_v[0] == 1 and (ssl_v[1] == 1 or (ssl_v[1] == 0 and ssl_v[2] >= 2))))


@@ -0,0 +1,33 @@
#!/usr/bin/env python
import argparse
import os
import server
def abs_path(path):
return os.path.abspath(path)
def parse_args():
parser = argparse.ArgumentParser(description="HTTP server designed for extreme flexibility "
"required in testing situations.")
parser.add_argument("document_root", action="store", type=abs_path,
help="Root directory to serve files from")
parser.add_argument("--port", "-p", dest="port", action="store",
type=int, default=8000,
help="Port number to run server on")
parser.add_argument("--host", "-H", dest="host", action="store",
type=str, default="127.0.0.1",
help="Host to run server on")
return parser.parse_args()
def main():
args = parse_args()
httpd = server.WebTestHttpd(host=args.host, port=args.port,
use_ssl=False, certificate=None,
doc_root=args.document_root)
httpd.start()
if __name__ == "__main__":
main()


@@ -0,0 +1,247 @@
"""This file provides the opening handshake processor for the Bootstrapping
WebSockets with HTTP/2 protocol (RFC 8441).
Specification:
https://tools.ietf.org/html/rfc8441
"""
from __future__ import absolute_import
from mod_pywebsocket import common
from mod_pywebsocket.stream import Stream
from mod_pywebsocket.stream import StreamOptions
from mod_pywebsocket import util
from six.moves import map
from six.moves import range
# TODO: We are using "private" methods of pywebsocket. We might want to
# refactor pywebsocket to expose those methods publicly. Also, _get_origin,
# _check_version, _set_protocol, _parse_extensions and a large part of do_handshake
# are identical to the hybi handshake. We need some refactoring to remove that
# code duplication.
from mod_pywebsocket.extensions import get_extension_processor
from mod_pywebsocket.handshake._base import get_mandatory_header
from mod_pywebsocket.handshake._base import HandshakeException
from mod_pywebsocket.handshake._base import parse_token_list
from mod_pywebsocket.handshake._base import validate_mandatory_header
from mod_pywebsocket.handshake._base import validate_subprotocol
from mod_pywebsocket.handshake._base import VersionException
# Defining aliases for values used frequently.
_VERSION_LATEST = common.VERSION_HYBI_LATEST
_VERSION_LATEST_STRING = str(_VERSION_LATEST)
_SUPPORTED_VERSIONS = [
_VERSION_LATEST,
]
def check_connect_method(request):
if request.method != u'CONNECT':
raise HandshakeException('Method is not CONNECT: %r' % request.method)
class WsH2Handshaker(object):
def __init__(self, request, dispatcher):
"""Opening handshake processor for the WebSocket protocol (RFC 6455).
:param request: mod_python request.
:param dispatcher: Dispatcher (dispatch.Dispatcher).
WsH2Handshaker will add attributes such as ws_resource during handshake.
"""
self._logger = util.get_class_logger(self)
self._request = request
self._dispatcher = dispatcher
def _get_extension_processors_requested(self):
processors = []
if self._request.ws_requested_extensions is not None:
for extension_request in self._request.ws_requested_extensions:
processor = get_extension_processor(extension_request)
# Unknown extension requests are just ignored.
if processor is not None:
processors.append(processor)
return processors
def do_handshake(self):
self._request.ws_close_code = None
self._request.ws_close_reason = None
# Parsing.
check_connect_method(self._request)
validate_mandatory_header(self._request, ':protocol', 'websocket')
self._request.ws_resource = self._request.uri
get_mandatory_header(self._request, 'authority')
self._request.ws_version = self._check_version()
try:
self._get_origin()
self._set_protocol()
self._parse_extensions()
# Setup extension processors.
self._request.ws_extension_processors = self._get_extension_processors_requested()
# List of extra headers. The extra handshake handler may add header
# data as name/value pairs to this list and pywebsocket appends
# them to the WebSocket handshake.
self._request.extra_headers = []
# Extra handshake handler may modify/remove processors.
self._dispatcher.do_extra_handshake(self._request)
processors = [
processor
for processor in self._request.ws_extension_processors
if processor is not None
]
# Ask each processor if there are extensions on the request which
# cannot co-exist. When processor decided other processors cannot
# co-exist with it, the processor marks them (or itself) as
# "inactive". The first extension processor has the right to
# make the final call.
for processor in reversed(processors):
if processor.is_active():
processor.check_consistency_with_other_processors(
processors)
processors = [
processor for processor in processors if processor.is_active()
]
accepted_extensions = []
stream_options = StreamOptions()
for index, processor in enumerate(processors):
if not processor.is_active():
continue
extension_response = processor.get_extension_response()
if extension_response is None:
# Rejected.
continue
accepted_extensions.append(extension_response)
processor.setup_stream_options(stream_options)
# Inactivate all of the following compression extensions.
for j in range(index + 1, len(processors)):
processors[j].set_active(False)
if len(accepted_extensions) > 0:
self._request.ws_extensions = accepted_extensions
self._logger.debug(
'Extensions accepted: %r',
list(
map(common.ExtensionParameter.name,
accepted_extensions)))
else:
self._request.ws_extensions = None
self._request.ws_stream = self._create_stream(stream_options)
if self._request.ws_requested_protocols is not None:
if self._request.ws_protocol is None:
raise HandshakeException(
'do_extra_handshake must choose one subprotocol from '
'ws_requested_protocols and set it to ws_protocol')
validate_subprotocol(self._request.ws_protocol)
self._logger.debug('Subprotocol accepted: %r',
self._request.ws_protocol)
else:
if self._request.ws_protocol is not None:
raise HandshakeException(
'ws_protocol must be None when the client didn\'t '
'request any subprotocol')
self._prepare_handshake_response()
except HandshakeException as e:
if not e.status:
# Fallback to 400 bad request by default.
e.status = common.HTTP_STATUS_BAD_REQUEST
raise e
def _get_origin(self):
origin = self._request.headers_in.get('origin')
if origin is None:
self._logger.debug('Client request does not have origin header')
self._request.ws_origin = origin
def _check_version(self):
sec_websocket_version_header = 'sec-websocket-version'
version = get_mandatory_header(self._request, sec_websocket_version_header)
if version == _VERSION_LATEST_STRING:
return _VERSION_LATEST
if version.find(',') >= 0:
raise HandshakeException(
'Multiple versions (%r) are not allowed for header %s' %
(version, sec_websocket_version_header),
status=common.HTTP_STATUS_BAD_REQUEST)
raise VersionException('Unsupported version %r for header %s' %
(version, sec_websocket_version_header),
supported_versions=', '.join(
map(str, _SUPPORTED_VERSIONS)))
def _set_protocol(self):
self._request.ws_protocol = None
protocol_header = self._request.headers_in.get('sec-websocket-protocol')
if protocol_header is None:
self._request.ws_requested_protocols = None
return
self._request.ws_requested_protocols = parse_token_list(
protocol_header)
self._logger.debug('Subprotocols requested: %r',
self._request.ws_requested_protocols)
def _parse_extensions(self):
extensions_header = self._request.headers_in.get('sec-websocket-extensions')
if not extensions_header:
self._request.ws_requested_extensions = None
return
try:
self._request.ws_requested_extensions = common.parse_extensions(
extensions_header)
except common.ExtensionParsingException as e:
raise HandshakeException(
'Failed to parse sec-websocket-extensions header: %r' % e)
self._logger.debug(
'Extensions requested: %r',
list(
map(common.ExtensionParameter.name,
self._request.ws_requested_extensions)))
def _create_stream(self, stream_options):
return Stream(self._request, stream_options)
def _prepare_handshake_response(self):
self._request.status = 200
self._request.headers_out['upgrade'] = common.WEBSOCKET_UPGRADE_TYPE
self._request.headers_out['connection'] = common.UPGRADE_CONNECTION_TYPE
if self._request.ws_protocol is not None:
self._request.headers_out['sec-websocket-protocol'] = self._request.ws_protocol
if (self._request.ws_extensions is not None and
len(self._request.ws_extensions) != 0):
self._request.headers_out['sec-websocket-extensions'] = common.format_extensions(self._request.ws_extensions)
# Headers not specific for WebSocket
for name, value in self._request.extra_headers:
self._request.headers_out[name] = value
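The extension-negotiation loop in `do_handshake` embodies a "first accepted wins" rule: as soon as one processor produces a response, every processor after it is deactivated. The following is a minimal model of just that loop; `StubProcessor` and `negotiate` are hypothetical names for illustration, not part of pywebsocket:

```python
class StubProcessor(object):
    """Toy stand-in for a pywebsocket extension processor."""

    def __init__(self, name, accepts=True):
        self.name = name
        self._accepts = accepts
        self._active = True

    def is_active(self):
        return self._active

    def set_active(self, value):
        self._active = value

    def get_extension_response(self):
        # Returning None models a rejected extension offer.
        return self.name if self._accepts else None


def negotiate(processors):
    # Mirror the accept loop above: skip inactive or rejecting processors,
    # and deactivate everything after the first accepted one.
    accepted = []
    for index, processor in enumerate(processors):
        if not processor.is_active():
            continue
        response = processor.get_extension_response()
        if response is None:
            continue
        accepted.append(response)
        for later in processors[index + 1:]:
            later.set_active(False)
    return accepted
```

With processors `a` (rejecting), `b`, and `c`, only `b` is accepted: `a` returns no response, `b` accepts and thereby deactivates `c`.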

View file

@@ -7,6 +7,7 @@ black:
- python/mozbuild/mozbuild/fork_interpose.py
- python/mozbuild/mozbuild/test/frontend/data/reader-error-syntax/moz.build
- testing/mozharness/configs/test/test_malformed.py
- testing/web-platform/mozilla/tests/tools/wptserve_py2
- testing/web-platform/tests
extensions:
- build

View file

@@ -166,6 +166,7 @@ file-whitespace:
- testing/web-platform/tests/tools/wptrunner/wptrunner/tests/test_update.py
- testing/web-platform/tests/tools/lint/tests/dummy/broken.html
- testing/web-platform/tests/tools/lint/tests/dummy/broken_ignored.html
- testing/web-platform/mozilla/tests/tools/wptserve_py2
- toolkit/components/telemetry/build_scripts/setup.py
- toolkit/components/telemetry/tests/marionette/mach_commands.py
- toolkit/content/tests/chrome

View file

@@ -17,6 +17,7 @@ py3:
- testing/mozharness
- testing/tps
- testing/web-platform/tests
- testing/web-platform/mozilla/tests/tools/wptserve_py2
- toolkit
- xpcom/idl-parser
extensions: ['py']