* Pipeline renaming

* Work in progress

* Work in progress 2

* Move formData to ClientRequest

* Pure ClientRequest with no requests

* Add kwargs for send

* Pipeline update

* First pass on tests

* Some typehints

* Fixing all tests

* Full Pipeline mypy

* Py3 mock compat

* Pipeline and stream download

* First pass on Autorest testserver

* While making coverage report, don't cry for async

* Pylint

* Add absolute_import for Py2.7

* Fix ABC for 2.7

* Add absolute_import to another file for 2.7

* Pipeline is a context manager

* Some async ABC

* Move mypy to 3.6

* Fix empty policies

* aiohttp proof of concepts

* Simplify ClientResponse

* Improve response handling + async fixes

* Fix mypy

* Create a basic HTTPSender

* async dependencies

* Fix Pipfile for asyncio

* Py3.5 compat

* Add universal to async tests

* Basic requests as asyncio impl

* Remove check_redirect from configuration

* Improve default pipeline

* Make pipeline a public attribute

* Split configuration for pipeline

* Refactor config

* Restore exception

* mypy happiness

* Split requests configuration

* Simplify keep_alive behavior

* Default parameter

* Rename to on_request/on_response after feedback

* Multi-thread compatible HTTP requests sender

* Squashed commit of the following:

commit 3246847c2f
Merge: 18cb696 388e8d0
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Thu Jun 14 15:26:57 2018 -0700

    Merge remote-tracking branch 'origin/master' into async2

commit 18cb696109
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Wed May 23 11:30:00 2018 -0700

    MyPy happiness

commit bd7123396b
Merge: a997e97 3a8b79d
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Wed May 23 11:23:44 2018 -0700

    Merge branch 'master' into async2

commit a997e97cd9
Merge: 4130eca 2b7d778
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Wed May 9 13:54:11 2018 -0700

    Merge remote-tracking branch 'origin/master' into async2

commit 4130eca92a
Merge: 8ffedd8 9d81113
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Fri Apr 20 10:38:45 2018 -0700

    Merge remote-tracking branch 'origin/master' into async2

commit 8ffedd8a3a
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Mar 20 15:36:40 2018 -0700

    Refactor a little async stream download

commit bbf1259ca8
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Fri Mar 16 17:20:07 2018 -0700

    Add stream upload support

commit 2d260036f6
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Feb 27 16:25:33 2018 -0800

    Fix incorrect request call

commit 6b55d4f633
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Wed Jan 17 13:39:18 2018 -0800

    Add status/finished to async poller

commit 02c333eb13
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jan 16 15:50:06 2018 -0800

    Port stream to async implementation

commit b3f0ac7d29
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jan 16 15:32:23 2018 -0800

    Add AsyncPoller

commit 3e9e17883e
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Thu Dec 7 16:13:06 2017 -0800

    Sync ServiceClientAsync with latest fixes

commit 5483e289b5
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Thu Jul 20 11:27:05 2017 -0700

    Address feedback from @brettcannon on async

commit c99f4b71a7
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jul 18 13:14:27 2017 -0700

    Robust coverage xml report

commit e0c6d3e42b
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jul 18 12:23:12 2017 -0700

    Rename SC mixin

commit 8e029ff0be
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jul 18 12:12:12 2017 -0700

    Add async_get to paging

commit f3dfaf6526
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jul 18 12:06:20 2017 -0700

    Rename paging mixin

commit 2f7142d211
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jul 18 12:05:34 2017 -0700

    async_get_next in paging

commit 9e821009c4
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jul 18 11:20:24 2017 -0700

    async send form data

commit 17045f776e
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jul 18 11:18:46 2017 -0700

    Add async client mixin

commit 3294115452
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jul 18 11:08:37 2017 -0700

    Fix Py3.5 async tests

commit 615f672aec
Author: Laurent Mazuel <laurent.mazuel@gmail.com>
Date:   Tue Jul 18 10:31:54 2017 -0700

    Async paging with mixin

* Remove useless tox line

* Add credentials to async requests

* Add trio support

* Revamp async stream download

* Add pipeline Response wrapper

* Introduce a raw deserializer as a policy

* SansIO on_exception

* Fix deserialization tests

* Implement #116 - Env variable for UserAgent

* Revamp streaming

* Logger should not log streamable response

* Put pipeline in config

* Mypy fixes

* Update mypy 0.620

* Fix trio dep

* Create universal HTTP package

* Pipeline as a universal HTTP implementation, no mypy

* Pipeline as a universal HTTP implementation, with mypy

* Backward compatible ServiceClient

* Doc update

* Make config optional in requests engine

* Fix types for MyPY 0.630

* Credentials can be directly a Policy

* msrest[async]

* 0.6.0rc1 first ChangeLog

* Don't coverage report the TYPE_CHECKING import

* Fix dev install on 3.5 and more

* 0.6.0rc1

* Pre-load aiohttp body

* sync ServiceClient returns requersts.Response

* Fix import issue

* Remove concept of requests_kwargs

* Fix hooks parameter
This commit is contained in:
Laurent Mazuel 2018-10-02 15:35:59 -07:00 коммит произвёл GitHub
Родитель 50c5546691
Коммит 3653d29fc4
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
47 изменённых файлов: 4413 добавлений и 1217 удалений

4
.coveragerc Normal file
Просмотреть файл

@ -0,0 +1,4 @@
[report]
exclude_lines =
pragma: no cover
if TYPE_CHECKING:

Просмотреть файл

@ -20,11 +20,11 @@ _autorest_install: &_autorest_install
jobs:
include:
- stage: MyPy
python: 3.5
python: 3.6
install:
- pip install mypy
script:
- mypy msrest --ignore-missing-imports
- mypy msrest
- stage: Test
python: 2.7
env: TOXENV=py27

Просмотреть файл

@ -4,12 +4,14 @@ verify_ssl = true
name = "pypi"
[packages]
"e1839a8" = {path = ".", editable = true}
"e1839a8" = {path = ".", extras = ["async"], editable = true}
[dev-packages]
pytest = "*"
pytest-cov = "*"
pytest-asyncio = {version = "*", markers="python_version >= '3.5'"}
httpretty = ">=0.8.10"
mock = {version = "*", markers="python_version <= '2.7'"}
mypy = {version = "==0.630", markers="python_version > '2.7'"}
pylint = "*"
trio = {version = "*", markers="python_version >= '3.5'"}

Просмотреть файл

@ -20,6 +20,41 @@ To install:
Release History
---------------
2018-XX-XX Version 0.6.0rc1
+++++++++++++++++++++++++++
**Features**
- The environment variable AZURE_HTTP_USER_AGENT, if present, is now injected part of the UserAgent
- New msrest.universal_http module. Provide tools to generic HTTP management (sync/async, requests/aiohttp, etc.)
- New **preview** msrest.pipeline implementation:
- A Pipeline is an ordered list of Policies than can process an HTTP request and response in a generic way.
- More details in the wiki page about Pipeline: https://github.com/Azure/msrest-for-python/wiki/msrest-0.6.0---Pipeline
- Adding new attribute to Configuration instance:
- http_logger_policy - Policy to handle HTTP logging
- user_agent_policy - Policy to handle HTTP logging
- pipeline - The current pipeline used by the SDK client
- async_pipeline - The current async pipeline used by the SDK client
- Installing "msrest[async]" now install the **experimental** async support
**Breaking changes**
- The HTTPDriver API introduced in 0.5.0 has been replaced by Pipeline.
- The following classes have been moved from "msrest.pipeline" to "msrest.universal_http":
- ClientRedirectPolicy
- ClientProxies
- ClientConnection
- The following classes have been moved from "msrest.pipeline" to "msrest.universal_http.requests":
- ClientRetryPolicy
2018-09-04 Version 0.5.5
++++++++++++++++++++++++

Просмотреть файл

@ -24,10 +24,10 @@
#
# --------------------------------------------------------------------------
from .version import msrest_version
from .configuration import Configuration
from .service_client import ServiceClient, SDKClient
from .serialization import Serializer, Deserializer
from .version import msrest_version
__all__ = [
"ServiceClient",

125
msrest/async_client.py Normal file
Просмотреть файл

@ -0,0 +1,125 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import asyncio
import functools
import logging
from typing import Any, Dict, List, Union, TYPE_CHECKING
from .universal_http import ClientRequest
from .universal_http.async_requests import AsyncRequestsHTTPSender
from .pipeline import Request, AsyncPipeline, AsyncHTTPPolicy, SansIOHTTPPolicy
from .pipeline.async_requests import (
AsyncPipelineRequestsHTTPSender,
AsyncRequestsCredentialsPolicy
)
from .pipeline.universal import (
HTTPLogger,
RawDeserializer,
)
if TYPE_CHECKING:
from .configuration import Configuration # pylint: disable=unused-import
_LOGGER = logging.getLogger(__name__)
class AsyncSDKClientMixin:
"""The base class of all generated SDK client.
"""
async def __aenter__(self):
await self._client.__aenter__()
return self
async def __aexit__(self, *exc_details):
await self._client.__aexit__(*exc_details)
class AsyncServiceClientMixin:
def __init__(self, creds: Any, config: 'Configuration') -> None:
# Don't do super, since I know it will be "object"
# super(AsyncServiceClientMixin, self).__init__(creds, config)
# "async_pipeline" be should accessible from "config"
# In legacy mode this is weird, this config is a parameter of "pipeline"
# Should be revamp one day.
self.config.async_pipeline = self._create_default_async_pipeline() # type: ignore
def _create_default_async_pipeline(self):
policies = [
self.config.user_agent_policy, # UserAgent policy
RawDeserializer(), # Deserialize the raw bytes
self.config.http_logger_policy # HTTP request/response log
] # type: List[Union[AsyncHTTPPolicy, SansIOHTTPPolicy]]
if self._creds:
if isinstance(self._creds, (AsyncHTTPPolicy, SansIOHTTPPolicy)):
policies.insert(1, self._creds)
else:
# Assume this is the old credentials class, and then requests. Wrap it.
policies.insert(1, AsyncRequestsCredentialsPolicy(self._creds))
return AsyncPipeline(
policies,
AsyncPipelineRequestsHTTPSender(
AsyncRequestsHTTPSender(self.config) # Send HTTP request using requests
)
)
async def __aenter__(self):
await self.config.async_pipeline.__aenter__()
return self
async def __aexit__(self, *exc_details):
await self.config.async_pipeline.__aexit__(*exc_details)
async def async_send(self, request, **kwargs):
"""Prepare and send request object according to configuration.
:param ClientRequest request: The request object to be sent.
:param dict headers: Any headers to add to the request.
:param content: Any body data to add to the request.
:param config: Any specific config overrides
"""
kwargs.setdefault('stream', True)
# In the current backward compatible implementation, return the HTTP response
# and plug context inside. Could be remove if we modify Autorest,
# but we still need it to be backward compatible
pipeline_response = await self.config.async_pipeline.run(request, **kwargs)
response = pipeline_response.http_response
response.context = pipeline_response.context
return response
def stream_download_async(self, response, user_callback):
"""Async Generator for streaming request body data.
:param response: The initial response
:param user_callback: Custom callback for monitoring progress.
"""
block = self.config.connection.data_block_size
return response.stream_download(block, user_callback)

74
msrest/async_paging.py Normal file
Просмотреть файл

@ -0,0 +1,74 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from collections.abc import AsyncIterator
import logging
_LOGGER = logging.getLogger(__name__)
class AsyncPagedMixin(AsyncIterator):
def __init__(self, *args, **kwargs):
"""Bring async to Paging.
"async_command" is mandatory keyword argument for this mixin to work.
"""
self._async_get_next = kwargs.get("async_command")
if not self._async_get_next:
_LOGGER.warning("Paging async iterator protocol is not available for %s",
self.__class__.__name__)
async def async_get(self, url):
"""Get an arbitrary page.
This resets the iterator and then fully consumes it to return the
specific page **only**.
:param str url: URL to arbitrary page results.
"""
self.reset()
self.next_link = url
return await self.async_advance_page()
async def async_advance_page(self):
if self.next_link is None:
raise StopAsyncIteration("End of paging")
self._current_page_iter_index = 0
self._response = await self._async_get_next(self.next_link)
self._derserializer(self, self._response)
return self.current_page
async def __anext__(self):
"""Iterate through responses."""
# Storing the list iterator might work out better, but there's no
# guarantee that some code won't replace the list entirely with a copy,
# invalidating an list iterator that might be saved between iterations.
if self.current_page and self._current_page_iter_index < len(self.current_page):
response = self.current_page[self._current_page_iter_index]
self._current_page_iter_index += 1
return response
else:
await self.async_advance_page()
return await self.__anext__()

Просмотреть файл

@ -31,36 +31,24 @@ try:
except ImportError:
import ConfigParser as configparser # type: ignore
from ConfigParser import NoOptionError # type: ignore
import platform
from typing import Dict, List, Any, Callable
import requests
from typing import TYPE_CHECKING, Optional, Dict, List, Any, Callable # pylint: disable=unused-import
from .exceptions import raise_with_traceback
from .pipeline import (
ClientRetryPolicy,
ClientRedirectPolicy,
ClientProxies,
ClientConnection)
from .version import msrest_version
from .pipeline import Pipeline
from .universal_http.requests import (
RequestHTTPSenderConfiguration
)
from .pipeline.universal import (
UserAgentPolicy,
HTTPLogger,
)
def default_session_configuration_callback(session, global_config, local_config, **kwargs):
# type: (requests.Session, Configuration, Dict[str,str], str) -> Dict[str, str]
"""Configuration callback if you need to change default session configuration.
if TYPE_CHECKING:
from .pipeline import AsyncPipeline
:param requests.Session session: The session.
:param Configuration global_config: The global configuration.
:param dict[str,str] local_config: The on-the-fly configuration passed on the call.
:param dict[str,str] kwargs: The current computed values for session.request method.
:return: Must return kwargs, to be passed to session.request. If None is return, initial kwargs will be used.
:rtype: dict[str,str]
"""
return kwargs
class Configuration(object):
class Configuration(RequestHTTPSenderConfiguration):
"""Client configuration.
:param str baseurl: REST API base URL.
@ -68,48 +56,28 @@ class Configuration(object):
"""
def __init__(self, base_url, filepath=None):
# type: (str, str) -> None
# type: (str, Optional[str]) -> None
super(Configuration, self).__init__(filepath)
# Service
self.base_url = base_url
# Communication configuration
self.connection = ClientConnection()
# User-Agent as a policy
self.user_agent_policy = UserAgentPolicy()
# Headers (sent with every requests)
self.headers = {} # type: Dict[str, str]
# HTTP logger policy
self.http_logger_policy = HTTPLogger()
# ProxyConfiguration
self.proxies = ClientProxies()
# The sync pipeline (will be replaced by the SDK default one, this instance if just for mypy)
self.pipeline = Pipeline() # type: Pipeline
# Retry configuration
self.retry_policy = ClientRetryPolicy()
# Redirect configuration
self.redirect_policy = ClientRedirectPolicy()
# User-Agent Header
self._user_agent = "python/{} ({}) requests/{} msrest/{}".format(
platform.python_version(),
platform.platform(),
requests.__version__,
msrest_version)
# Should we log HTTP requests/response?
self.enable_http_logger = False
# Requests hooks. Must respect requests hook callback signature
# Note that we will inject the following parameters:
# - kwargs['msrest']['session'] with the current session
self.hooks = [] # type: List[Callable[[requests.Response, str, str], None]]
self.session_configuration_callback = default_session_configuration_callback
# The async pipeline
# This is actual optional, since on 2.7 this will be None
self.async_pipeline = None # type: Optional[AsyncPipeline]
# If set to True, ServiceClient will own the sessionn
self.keep_alive = False
self._config = configparser.ConfigParser()
self._config.optionxform = str # type: ignore
if filepath:
self.load(filepath)
@ -117,7 +85,7 @@ class Configuration(object):
def user_agent(self):
# type: () -> str
"""The current user agent value."""
return self._user_agent
return self.user_agent_policy.user_agent
def add_user_agent(self, value):
# type: (str) -> None
@ -125,94 +93,12 @@ class Configuration(object):
:param str value: value to add to user agent.
"""
self._user_agent = "{} {}".format(self._user_agent, value)
self.user_agent_policy.add_user_agent(value)
def _clear_config(self):
# type: () -> None
"""Clearout config object in memory."""
for section in self._config.sections():
self._config.remove_section(section)
@property
def enable_http_logger(self):
return self.http_logger_policy.enable_http_logger
def save(self, filepath):
# type: (str) -> None
"""Save current configuration to file.
:param str filepath: Path to file where settings will be saved.
:raises: ValueError if supplied filepath cannot be written to.
"""
sections = [
"Connection",
"Proxies",
"RetryPolicy",
"RedirectPolicy"]
for section in sections:
self._config.add_section(section)
self._config.set("Connection", "base_url", self.base_url)
self._config.set("Connection", "timeout", self.connection.timeout)
self._config.set("Connection", "verify", self.connection.verify)
self._config.set("Connection", "cert", self.connection.cert)
self._config.set("Proxies", "proxies", self.proxies.proxies)
self._config.set("Proxies", "env_settings",
self.proxies.use_env_settings)
self._config.set("RetryPolicy", "retries", str(self.retry_policy.retries))
self._config.set("RetryPolicy", "backoff_factor",
str(self.retry_policy.backoff_factor))
self._config.set("RetryPolicy", "max_backoff",
str(self.retry_policy.max_backoff))
self._config.set("RedirectPolicy", "allow", self.redirect_policy.allow)
self._config.set("RedirectPolicy", "max_redirects",
self.redirect_policy.max_redirects)
try:
with open(filepath, 'w') as configfile:
self._config.write(configfile)
except (KeyError, EnvironmentError):
error = "Supplied config filepath invalid."
raise_with_traceback(ValueError, error)
finally:
self._clear_config()
def load(self, filepath):
# type: (str) -> None
"""Load configuration from existing file.
:param str filepath: Path to existing config file.
:raises: ValueError if supplied config file is invalid.
"""
try:
self._config.read(filepath)
self.base_url = \
self._config.get("Connection", "base_url")
self.connection.timeout = \
self._config.getint("Connection", "timeout")
self.connection.verify = \
self._config.getboolean("Connection", "verify")
self.connection.cert = \
self._config.get("Connection", "cert")
self.proxies.proxies = \
eval(self._config.get("Proxies", "proxies"))
self.proxies.use_env_settings = \
self._config.getboolean("Proxies", "env_settings")
self.retry_policy.retries = \
self._config.getint("RetryPolicy", "retries")
self.retry_policy.backoff_factor = \
self._config.getfloat("RetryPolicy", "backoff_factor")
self.retry_policy.max_backoff = \
self._config.getint("RetryPolicy", "max_backoff")
self.redirect_policy.allow = \
self._config.getboolean("RedirectPolicy", "allow")
self.redirect_policy.max_redirects = \
self._config.getint("RedirectPolicy", "max_redirects")
except (ValueError, EnvironmentError, NoOptionError):
error = "Supplied config file incompatible."
raise_with_traceback(ValueError, error)
finally:
self._clear_config()
@enable_http_logger.setter
def enable_http_logger(self, value):
self.http_logger_policy.enable_http_logger = value

Просмотреть файл

@ -37,7 +37,7 @@ _LOGGER = logging.getLogger(__name__)
def raise_with_traceback(exception, message="", *args, **kwargs):
# type: (Callable, str, str, str) -> None
# type: (Callable, str, Any, Any) -> None
"""Raise exception with a specified traceback.
This MUST be called inside a "except" clause.
@ -140,10 +140,13 @@ class HttpOperationError(ClientException):
def __init__(self, deserialize, response,
resp_type=None, *args, **kwargs):
# type: (Deserializer, requests.Response, Optional[str], str, str) -> None
# type: (Deserializer, Any, Optional[str], str, str) -> None
self.error = None
self.message = self._DEFAULT_MESSAGE
self.response = response
if hasattr(response, 'internal_response'):
self.response = response.internal_response
else:
self.response = response
try:
if resp_type:
self.error = deserialize(resp_type, response)
@ -168,9 +171,11 @@ class HttpOperationError(ClientException):
try:
response.raise_for_status()
# Two possible raises here:
# - Attribute error if response is not requests.RequestException. Do not catch.
# - requests.RequestException. Catch base class IOError to avoid explicit import of requests here.
except IOError as err:
# - Attribute error if response is not ClientResponse. Do not catch.
# - Any internal exception, take it.
except AttributeError:
raise
except Exception as err: # pylint: disable=broad-except
if not self.error:
self.error = err

Просмотреть файл

@ -28,16 +28,16 @@ import logging
import re
import types
from typing import Any, Union, Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING # pylint: disable=unused-import
if TYPE_CHECKING:
import requests
from .universal_http import ClientRequest, ClientResponse # pylint: disable=unused-import
_LOGGER = logging.getLogger(__name__)
def log_request(_, request, *args, **kwargs):
# type: (Any, requests.PreparedRequest, str, str) -> None
def log_request(_, request, *_args, **_kwargs):
# type: (Any, ClientRequest, str, str) -> None
"""Log a client request.
:param _: Unused in current version (will be None)
@ -61,12 +61,12 @@ def log_request(_, request, *args, **kwargs):
_LOGGER.debug("File upload")
else:
_LOGGER.debug(str(request.body))
except Exception as err:
except Exception as err: # pylint: disable=broad-except
_LOGGER.debug("Failed to log request: %r", err)
def log_response(_, request, response, *args, **kwargs):
# type: (Any, requests.PreparedRequest, requests.Response, str, str) -> Optional[requests.Response]
def log_response(_, _request, response, *_args, **kwargs):
# type: (Any, ClientRequest, ClientResponse, str, Any) -> Optional[ClientResponse]
"""Log a server response.
:param _: Unused in current version (will be None)
@ -89,14 +89,17 @@ def log_response(_, request, response, *args, **kwargs):
if header and pattern.match(header):
filename = header.partition('=')[2]
_LOGGER.debug("File attachments: " + filename)
_LOGGER.debug("File attachments: %s", filename)
elif response.headers.get("content-type", "").endswith("octet-stream"):
_LOGGER.debug("Body contains binary data.")
elif response.headers.get("content-type", "").startswith("image"):
_LOGGER.debug("Body contains image data.")
else:
_LOGGER.debug(str(response.content))
if kwargs.get('stream', False):
_LOGGER.debug("Body is streamable")
else:
_LOGGER.debug(response.text())
return response
except Exception as err:
_LOGGER.debug("Failed to log response: " + repr(err))
except Exception as err: # pylint: disable=broad-except
_LOGGER.debug("Failed to log response: %s", repr(err))
return response

Просмотреть файл

@ -23,39 +23,51 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import sys
try:
from collections.abc import Iterator
xrange = range
except ImportError:
from collections import Iterator
from typing import Dict, Any, List, Callable, Optional, TYPE_CHECKING
from typing import Dict, Any, List, Callable, Optional, TYPE_CHECKING # pylint: disable=unused-import
if TYPE_CHECKING:
import requests
from .serialization import Deserializer, Model
from .serialization import Deserializer
from .pipeline import ClientRawResponse
if TYPE_CHECKING:
from .universal_http import ClientResponse # pylint: disable=unused-import
from .serialization import Model # pylint: disable=unused-import
class Paged(Iterator):
if sys.version_info >= (3, 5, 2):
# Not executed on old Python, no syntax error
from .async_paging import AsyncPagedMixin # type: ignore
else:
class AsyncPagedMixin(object): # type: ignore
pass
class Paged(AsyncPagedMixin, Iterator):
"""A container for paged REST responses.
:param requests.Response response: server response object.
:param ClientResponse response: server response object.
:param callable command: Function to retrieve the next page of items.
:param dict classes: A dictionary of class dependencies for
deserialization.
:param dict raw_headers: A dict of raw headers to add if "raw" is called
"""
_validation = {} # type: Dict[str, Dict[str, Any]]
_attribute_map = {} # type: Dict[str, Dict[str, Any]]
def __init__(self, command, classes, raw_headers=None):
# type: (Callable[[str], requests.Response], Dict[str, Model], Dict[str, str]) -> None
def __init__(self, command, classes, raw_headers=None, **kwargs):
# type: (Callable[[str], ClientResponse], Dict[str, Model], Dict[str, str], Any) -> None
super(Paged, self).__init__(**kwargs) # type: ignore
# Sets next_link, current_page, and _current_page_iter_index.
self.next_link = ""
self._current_page_iter_index = 0
self.reset()
self._derserializer = Deserializer(classes)
self._get_next = command
self._response = None # type: Optional[requests.Response]
self._response = None # type: Optional[ClientResponse]
self._raw_headers = raw_headers
def __iter__(self):

Просмотреть файл

@ -1,273 +0,0 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import functools
import json
import logging
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
import xml.etree.ElementTree as ET
from typing import Dict, Any, Optional, Union, List, TYPE_CHECKING
if TYPE_CHECKING:
import xml.etree.ElementTree as ET
import requests
from urllib3 import Retry # Needs requests 2.16 at least to be safe
from .serialization import Deserializer, Model
_LOGGER = logging.getLogger(__name__)
class ClientRequest(requests.Request):
"""Wrapper for requests.Request object."""
def format_parameters(self, params):
# type: (Dict[str, str]) -> None
"""Format parameters into a valid query string.
It's assumed all parameters have already been quoted as
valid URL strings.
:param dict params: A dictionary of parameters.
"""
query = urlparse(self.url).query
if query:
self.url = self.url.partition('?')[0]
existing_params = {
p[0]: p[-1]
for p in [p.partition('=') for p in query.split('&')]
}
params.update(existing_params)
query_params = ["{}={}".format(k, v) for k, v in params.items()]
query = '?' + '&'.join(query_params)
self.url = self.url + query
def add_content(self, data):
# type: (Optional[Union[Dict[str, Any], ET.Element]]) -> None
"""Add a body to the request.
:param data: Request body data, can be a json serializable
object (e.g. dictionary) or a generator (e.g. file data).
"""
if data is None:
return
if isinstance(data, ET.Element):
self.data = ET.tostring(data, encoding="utf8")
self.headers['Content-Length'] = str(len(self.data))
return
# By default, assume JSON
try:
self.data = json.dumps(data)
self.headers['Content-Length'] = str(len(self.data))
except TypeError:
self.data = data
class ClientRawResponse(object):
"""Wrapper for response object.
This allows for additional data to be gathereded from the response,
for example deserialized headers.
It also allows the raw response object to be passed back to the user.
:param output: Deserialized response object.
:param response: Raw response object.
"""
def __init__(self, output, response):
# type: (Union[Model, List[Model]], Optional[requests.Response]) -> None
self.response = response
self.output = output
self.headers = {} # type: Dict[str, Optional[Any]]
self._deserialize = Deserializer()
def add_headers(self, header_dict):
# type: (Dict[str, str]) -> None
"""Deserialize a specific header.
:param dict header_dict: A dictionary containing the name of the
header and the type to deserialize to.
"""
if not self.response:
return
for name, data_type in header_dict.items():
value = self.response.headers.get(name)
value = self._deserialize(data_type, value)
self.headers[name] = value
class ClientRetryPolicy(object):
"""Retry configuration settings.
Container for retry policy object.
"""
safe_codes = [i for i in range(500) if i != 408] + [501, 505]
def __init__(self):
self.policy = Retry()
self.policy.total = 3
self.policy.connect = 3
self.policy.read = 3
self.policy.backoff_factor = 0.8
self.policy.BACKOFF_MAX = 90
retry_codes = [i for i in range(999) if i not in self.safe_codes]
self.policy.status_forcelist = retry_codes
self.policy.method_whitelist = ['HEAD', 'TRACE', 'GET', 'PUT',
'OPTIONS', 'DELETE', 'POST', 'PATCH']
def __call__(self):
# type: () -> Retry
"""Return configuration to be applied to connection."""
debug = ("Configuring retry: max_retries=%r, "
"backoff_factor=%r, max_backoff=%r")
_LOGGER.debug(
debug, self.retries, self.backoff_factor, self.max_backoff)
return self.policy
@property
def retries(self):
# type: () -> int
"""Total number of allowed retries."""
return self.policy.total
@retries.setter
def retries(self, value):
# type: (int) -> None
self.policy.total = value
self.policy.connect = value
self.policy.read = value
@property
def backoff_factor(self):
# type: () -> Union[int, float]
"""Factor by which back-off delay is incementally increased."""
return self.policy.backoff_factor
@backoff_factor.setter
def backoff_factor(self, value):
# type: (Union[int, float]) -> None
self.policy.backoff_factor = value
@property
def max_backoff(self):
# type: () -> int
"""Max retry back-off delay."""
return self.policy.BACKOFF_MAX
@max_backoff.setter
def max_backoff(self, value):
# type: (int) -> None
self.policy.BACKOFF_MAX = value
class ClientRedirectPolicy(object):
"""Redirect configuration settings.
"""
def __init__(self):
self.allow = True
self.max_redirects = 30
def __bool__(self):
# type: () -> bool
"""Whether redirects are allowed."""
return self.allow
def __call__(self):
# type: () -> int
"""Return configuration to be applied to connection."""
debug = "Configuring redirects: allow=%r, max=%r"
_LOGGER.debug(debug, self.allow, self.max_redirects)
return self.max_redirects
def check_redirect(self, resp, request):
# type: (requests.Response, requests.PreparedRequest) -> bool
"""Whether redirect policy should be applied based on status code."""
if resp.status_code in (301, 302) and \
request.method not in ['GET', 'HEAD']:
return False
return True
class ClientProxies(object):
"""Proxy configuration settings.
Proxies can also be configured using HTTP_PROXY and HTTPS_PROXY
environment variables, in which case set use_env_settings to True.
"""
def __init__(self):
self.proxies = {}
self.use_env_settings = True
def __call__(self):
# type: () -> Dict[str, str]
"""Return configuration to be applied to connection."""
proxy_string = "\n".join(
[" {}: {}".format(k, v) for k, v in self.proxies.items()])
_LOGGER.debug("Configuring proxies: %r", proxy_string)
debug = "Evaluate proxies against ENV settings: %r"
_LOGGER.debug(debug, self.use_env_settings)
return self.proxies
def add(self, protocol, proxy_url):
# type: (str, str) -> None
"""Add proxy.
:param str protocol: Protocol for which proxy is to be applied. Can
be 'http', 'https', etc. Can also include host.
:param str proxy_url: The proxy URL. Where basic auth is required,
use the format: http://user:password@host
"""
self.proxies[protocol] = proxy_url
class ClientConnection(object):
"""Request connection configuration settings.
"""
def __init__(self):
self.timeout = 100
self.verify = True
self.cert = None
self.data_block_size = 4096
def __call__(self):
# type: () -> Dict[str, Union[str, int]]
"""Return configuration to be applied to connection."""
debug = "Configuring request: timeout=%r, verify=%r, cert=%r"
_LOGGER.debug(debug, self.timeout, self.verify, self.cert)
return {'timeout': self.timeout,
'verify': self.verify,
'cert': self.cert}

328
msrest/pipeline/__init__.py Normal file
Просмотреть файл

@ -0,0 +1,328 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from __future__ import absolute_import # we have a "requests" module that conflicts with "requests" on Py2.7
import abc
try:
import configparser
from configparser import NoOptionError
except ImportError:
import ConfigParser as configparser # type: ignore
from ConfigParser import NoOptionError # type: ignore
import json
import logging
import os.path
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
import xml.etree.ElementTree as ET
from typing import TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, Optional, Tuple, Callable, Iterator # pylint: disable=unused-import
HTTPResponseType = TypeVar("HTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")
# This file is NOT using any "requests" HTTP implementation
# However, the CaseInsensitiveDict is handy.
# If one day we reach the point where "requests" can be skip totally,
# might provide our own implementation
from requests.structures import CaseInsensitiveDict
from ..exceptions import ClientRequestError, raise_with_traceback
from ..universal_http import ClientResponse
if TYPE_CHECKING:
from ..serialization import Model # pylint: disable=unused-import
_LOGGER = logging.getLogger(__name__)
try:
ABC = abc.ABC
except AttributeError: # Python 2.7, abc exists, but not ABC
ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) # type: ignore
try:
from contextlib import AbstractContextManager # type: ignore
except ImportError: # Python <= 3.5
class AbstractContextManager(object): # type: ignore
def __enter__(self):
"""Return `self` upon entering the runtime context."""
return self
@abc.abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]):
"""An http policy ABC.
"""
def __init__(self):
self.next = None
@abc.abstractmethod
def send(self, request, **kwargs):
# type: (Request[HTTPRequestType], Any) -> Response[HTTPRequestType, HTTPResponseType]
"""Mutate the request.
Context content is dependent of the HTTPSender.
"""
pass
class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
"""Represents a sans I/O policy.
This policy can act before the I/O, and after the I/O.
Use this policy if the actual I/O in the middle is an implementation
detail.
Context is not available, since it's implementation dependent.
if a policy needs a context of the Sender, it can't be universal.
Example: setting a UserAgent does not need to be tight to
sync or async implementation or specific HTTP lib
"""
def on_request(self, request, **kwargs):
# type: (Request[HTTPRequestType], Any) -> None
"""Is executed before sending the request to next policy.
"""
pass
def on_response(self, request, response, **kwargs):
# type: (Request[HTTPRequestType], Response[HTTPRequestType, HTTPResponseType], Any) -> None
"""Is executed after the request comes back from the policy.
"""
pass
def on_exception(self, request, **kwargs):
# type: (Request[HTTPRequestType], Any) -> bool
"""Is executed if an exception comes back fron the following
policy.
Return True if the exception has been handled and should not
be forwarded to the caller.
This method is executed inside the exception handler.
To get the exception, raise and catch it:
try:
raise
except MyError:
do_something()
or use
exc_type, exc_value, exc_traceback = sys.exc_info()
"""
return False
class _SansIOHTTPPolicyRunner(HTTPPolicy, Generic[HTTPRequestType, HTTPResponseType]):
"""Sync implementation of the SansIO policy.
"""
def __init__(self, policy):
# type: (SansIOHTTPPolicy) -> None
super(_SansIOHTTPPolicyRunner, self).__init__()
self._policy = policy
def send(self, request, **kwargs):
# type: (Request[HTTPRequestType], Any) -> Response[HTTPRequestType, HTTPResponseType]
self._policy.on_request(request, **kwargs)
try:
response = self.next.send(request, **kwargs)
except Exception:
if not self._policy.on_exception(request, **kwargs):
raise
else:
self._policy.on_response(request, response, **kwargs)
return response
class Pipeline(AbstractContextManager, Generic[HTTPRequestType, HTTPResponseType]):
"""A pipeline implementation.
This is implemented as a context manager, that will activate the context
of the HTTP sender.
"""
def __init__(self, policies=None, sender=None):
# type: (List[Union[HTTPPolicy, SansIOHTTPPolicy]], HTTPSender) -> None
self._impl_policies = [] # type: List[HTTPPolicy]
if not sender:
# Import default only if nothing is provided
from .requests import PipelineRequestsHTTPSender
self._sender = cast(HTTPSender, PipelineRequestsHTTPSender())
else:
self._sender = sender
for policy in (policies or []):
if isinstance(policy, SansIOHTTPPolicy):
self._impl_policies.append(_SansIOHTTPPolicyRunner(policy))
else:
self._impl_policies.append(policy)
for index in range(len(self._impl_policies)-1):
self._impl_policies[index].next = self._impl_policies[index+1]
if self._impl_policies:
self._impl_policies[-1].next = self._sender
def __enter__(self):
# type: () -> Pipeline
self._sender.__enter__()
return self
def __exit__(self, *exc_details): # pylint: disable=arguments-differ
self._sender.__exit__(*exc_details)
def run(self, request, **kwargs):
# type: (HTTPRequestType, Any) -> Response
context = self._sender.build_context()
pipeline_request = Request(request, context) # type: Request[HTTPRequestType]
first_node = self._impl_policies[0] if self._impl_policies else self._sender
return first_node.send(pipeline_request, **kwargs) # type: ignore
class HTTPSender(AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType]):
"""An http sender ABC.
"""
@abc.abstractmethod
def send(self, request, **config):
# type: (Request[HTTPRequestType], Any) -> Response[HTTPRequestType, HTTPResponseType]
"""Send the request using this HTTP sender.
"""
pass
def build_context(self):
# type: () -> Any
"""Allow the sender to build a context that will be passed
across the pipeline with the request.
Return type has no constraints. Implementation is not
required and None by default.
"""
return None
class Request(Generic[HTTPRequestType]):
"""Represents a HTTP request in a Pipeline.
URL can be given without query parameters, to be added later using "format_parameters".
Instance can be created without data, to be added later using "add_content"
Instance can be created without files, to be added later using "add_formdata"
:param str method: HTTP method (GET, HEAD, etc.)
:param str url: At least complete scheme/host/path
:param dict[str,str] headers: HTTP headers
:param files: Files list.
:param data: Body to be sent.
:type data: bytes or str.
"""
def __init__(self, http_request, context=None):
# type: (HTTPRequestType, Optional[Any]) -> None
self.http_request = http_request
self.context = context
class Response(Generic[HTTPRequestType, HTTPResponseType]):
"""A pipeline response object.
The Response interface exposes an HTTP response object as it returns through the pipeline of Policy objects.
This ensures that Policy objects have access to the HTTP response.
This also have a "context" dictionnary where policy can put additional fields.
Policy SHOULD update the "context" dictionary with additional post-processed field if they create them.
However, nothing prevents a policy to actually sub-class this class a return it instead of the initial instance.
"""
def __init__(self, request, http_response, context=None):
# type: (Request[HTTPRequestType], HTTPResponseType, Optional[Dict[str, Any]]) -> None
self.request = request
self.http_response = http_response
self.context = context or {}
# ClientRawResponse is in Pipeline for compat, but technically there is nothing Pipeline here, this is deserialization
class ClientRawResponse(object):
"""Wrapper for response object.
This allows for additional data to be gathereded from the response,
for example deserialized headers.
It also allows the raw response object to be passed back to the user.
:param output: Deserialized response object.
:param response: Raw response object.
"""
def __init__(self, output, response):
# type: (Union[Model, List[Model]], Optional[Union[Response, ClientResponse]]) -> None
from ..serialization import Deserializer
if isinstance(response, Response):
# If pipeline response, remove that layer
response = response.http_response
if isinstance(response, ClientResponse):
# If universal driver, remove that layer
self.response = response.internal_response
else:
self.response = response
self.output = output
self.headers = {} # type: Dict[str, Optional[Any]]
self._deserialize = Deserializer()
def add_headers(self, header_dict):
# type: (Dict[str, str]) -> None
"""Deserialize a specific header.
:param dict header_dict: A dictionary containing the name of the
header and the type to deserialize to.
"""
if not self.response:
return
for name, data_type in header_dict.items():
value = self.response.headers.get(name)
value = self._deserialize(data_type, value)
self.headers[name] = value
__all__ = [
'Request',
'Response',
'Pipeline',
'HTTPPolicy',
'SansIOHTTPPolicy',
'HTTPSender',
# backward compat
'ClientRawResponse',
]
try:
from .async_abc import AsyncPipeline, AsyncHTTPPolicy, AsyncHTTPSender # pylint: disable=unused-import
from .async_abc import __all__ as _async_all
__all__ += _async_all
except SyntaxError: # Python 2
pass

Просмотреть файл

@ -0,0 +1,62 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import Any, Optional
from ..universal_http.aiohttp import AioHTTPSender as _AioHTTPSenderDriver
from . import AsyncHTTPSender, Request, Response
# Matching requests, because why not?
CONTENT_CHUNK_SIZE = 10 * 1024
class AioHTTPSender(AsyncHTTPSender):
"""AioHttp HTTP sender implementation.
"""
def __init__(self, driver: Optional[_AioHTTPSenderDriver] = None, *, loop=None) -> None:
self.driver = driver or _AioHTTPSenderDriver(loop=loop)
async def __aenter__(self):
await self.driver.__aenter__()
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
await self.driver.__aexit__(*exc_details)
def build_context(self) -> Any:
"""Allow the sender to build a context that will be passed
across the pipeline with the request.
Return type has no constraints. Implementation is not
required and None by default.
"""
return None
async def send(self, request: Request, **config: Any) -> Response:
"""Send the request using this HTTP sender.
"""
return Response(
request,
await self.driver.send(request.http_request)
)

Просмотреть файл

@ -0,0 +1,165 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import abc
from typing import Any, List, Union, Callable, AsyncIterator, Optional, Generic, TypeVar
from . import Request, Response, Pipeline, SansIOHTTPPolicy
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
HTTPRequestType = TypeVar("HTTPRequestType")
try:
from contextlib import AbstractAsyncContextManager # type: ignore
except ImportError: # Python <= 3.7
class AbstractAsyncContextManager(object): # type: ignore
async def __aenter__(self):
"""Return `self` upon entering the runtime context."""
return self
@abc.abstractmethod
async def __aexit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]):
"""An http policy ABC.
"""
def __init__(self) -> None:
# next will be set once in the pipeline
self.next = None # type: Optional[Union[AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType], AsyncHTTPSender[HTTPRequestType, AsyncHTTPResponseType]]]
@abc.abstractmethod
async def send(self, request: Request, **kwargs: Any) -> Response[HTTPRequestType, AsyncHTTPResponseType]:
"""Mutate the request.
Context content is dependent of the HTTPSender.
"""
pass
class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
"""Async implementation of the SansIO policy.
"""
def __init__(self, policy: SansIOHTTPPolicy) -> None:
super(_SansIOAsyncHTTPPolicyRunner, self).__init__()
self._policy = policy
async def send(self, request: Request, **kwargs: Any) -> Response[HTTPRequestType, AsyncHTTPResponseType]:
self._policy.on_request(request, **kwargs)
try:
response = await self.next.send(request, **kwargs) # type: ignore
except Exception:
if not self._policy.on_exception(request, **kwargs):
raise
else:
self._policy.on_response(request, response, **kwargs)
return response
class AsyncHTTPSender(AbstractAsyncContextManager, abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]):
"""An http sender ABC.
"""
@abc.abstractmethod
async def send(self, request: Request[HTTPRequestType], **config: Any) -> Response[HTTPRequestType, AsyncHTTPResponseType]:
"""Send the request using this HTTP sender.
"""
pass
def build_context(self) -> Any:
"""Allow the sender to build a context that will be passed
across the pipeline with the request.
Return type has no constraints. Implementation is not
required and None by default.
"""
return None
def __enter__(self):
raise TypeError("Use async with instead")
def __exit__(self, exc_type, exc_val, exc_tb):
# __exit__ should exist in pair with __enter__ but never executed
pass # pragma: no cover
class AsyncPipeline(AbstractAsyncContextManager, Generic[HTTPRequestType, AsyncHTTPResponseType]):
"""A pipeline implementation.
This is implemented as a context manager, that will activate the context
of the HTTP sender.
"""
def __init__(self, policies: List[Union[AsyncHTTPPolicy, SansIOHTTPPolicy]] = None, sender: Optional[AsyncHTTPSender[HTTPRequestType, AsyncHTTPResponseType]] = None) -> None:
self._impl_policies = [] # type: List[AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]]
if sender:
self._sender = sender
else:
# Import default only if nothing is provided
from .aiohttp import AioHTTPSender
self._sender = AioHTTPSender()
for policy in (policies or []):
if isinstance(policy, SansIOHTTPPolicy):
self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy))
else:
self._impl_policies.append(policy)
for index in range(len(self._impl_policies)-1):
self._impl_policies[index].next = self._impl_policies[index+1]
if self._impl_policies:
self._impl_policies[-1].next = self._sender
def __enter__(self):
raise TypeError("Use 'async with' instead")
def __exit__(self, exc_type, exc_val, exc_tb):
# __exit__ should exist in pair with __enter__ but never executed
pass # pragma: no cover
async def __aenter__(self) -> 'AsyncPipeline':
await self._sender.__aenter__()
return self
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
await self._sender.__aexit__(*exc_details)
async def run(self, request: Request, **kwargs: Any) -> Response[HTTPRequestType, AsyncHTTPResponseType]:
context = self._sender.build_context()
pipeline_request = Request(request, context)
first_node = self._impl_policies[0] if self._impl_policies else self._sender
return await first_node.send(pipeline_request, **kwargs) # type: ignore
__all__ = [
'AsyncHTTPPolicy',
'AsyncHTTPSender',
'AsyncPipeline',
]

Просмотреть файл

@ -0,0 +1,129 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import asyncio
from collections.abc import AsyncIterator
import functools
import logging
from typing import Any, Callable, Optional, AsyncIterator as AsyncIteratorType
from oauthlib import oauth2
import requests
from requests.models import CONTENT_CHUNK_SIZE
from ..exceptions import (
TokenExpiredError,
ClientRequestError,
raise_with_traceback
)
from ..universal_http.async_requests import AsyncBasicRequestsHTTPSender
from . import AsyncHTTPSender, AsyncHTTPPolicy, Response, Request
from .requests import RequestsContext
_LOGGER = logging.getLogger(__name__)
class AsyncPipelineRequestsHTTPSender(AsyncHTTPSender):
"""Implements a basic Pipeline, that supports universal HTTP lib "requests" driver.
"""
def __init__(self, universal_http_requests_driver: Optional[AsyncBasicRequestsHTTPSender]=None) -> None:
self.driver = universal_http_requests_driver or AsyncBasicRequestsHTTPSender()
async def __aenter__(self) -> 'AsyncPipelineRequestsHTTPSender':
await self.driver.__aenter__()
return self
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
await self.driver.__aexit__(*exc_details)
async def close(self):
await self.__aexit__()
def build_context(self):
# type: () -> RequestsContext
return RequestsContext(
session=self.driver.session,
)
async def send(self, request: Request, **kwargs) -> Response:
"""Send request object according to configuration.
:param Request request: The request object to be sent.
"""
if request.context is None: # Should not happen, but make mypy happy and does not hurt
request.context = self.build_context()
if request.context.session is not self.driver.session:
kwargs['session'] = request.context.session
return Response(
request,
await self.driver.send(request.http_request, **kwargs)
)
class AsyncRequestsCredentialsPolicy(AsyncHTTPPolicy):
"""Implementation of request-oauthlib except and retry logic.
"""
def __init__(self, credentials):
super(AsyncRequestsCredentialsPolicy, self).__init__()
self._creds = credentials
async def send(self, request, **kwargs):
session = request.context.session
try:
self._creds.signed_session(session)
except TypeError: # Credentials does not support session injection
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
request.context.session = session = self._creds.signed_session()
try:
try:
return await self.next.send(request, **kwargs)
except (oauth2.rfc6749.errors.InvalidGrantError,
oauth2.rfc6749.errors.TokenExpiredError) as err:
error = "Token expired or is invalid. Attempting to refresh."
_LOGGER.warning(error)
try:
try:
self._creds.refresh_session(session)
except TypeError: # Credentials does not support session injection
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
request.context.session = session = self._creds.refresh_session()
return await self.next.send(request, **kwargs)
except (oauth2.rfc6749.errors.InvalidGrantError,
oauth2.rfc6749.errors.TokenExpiredError) as err:
msg = "Token expired or is invalid."
raise_with_traceback(TokenExpiredError, msg, err)
except (requests.RequestException,
oauth2.rfc6749.errors.OAuth2Error) as err:
msg = "Error occurred in request."
raise_with_traceback(ClientRequestError, msg, err)

194
msrest/pipeline/requests.py Normal file
Просмотреть файл

@ -0,0 +1,194 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
"""
This module is the requests implementation of Pipeline ABC
"""
from __future__ import absolute_import # we have a "requests" module that conflicts with "requests" on Py2.7
import contextlib
import logging
import threading
from typing import TYPE_CHECKING, List, Callable, Iterator, Any, Union, Dict, Optional # pylint: disable=unused-import
import warnings
from oauthlib import oauth2
import requests
from requests.models import CONTENT_CHUNK_SIZE
from urllib3 import Retry # Needs requests 2.16 at least to be safe
from ..exceptions import (
TokenExpiredError,
ClientRequestError,
raise_with_traceback
)
from ..universal_http import ClientRequest
from ..universal_http.requests import BasicRequestsHTTPSender
from . import HTTPSender, HTTPPolicy, Response, Request
_LOGGER = logging.getLogger(__name__)
class RequestsCredentialsPolicy(HTTPPolicy):
"""Implementation of request-oauthlib except and retry logic.
"""
def __init__(self, credentials):
super(RequestsCredentialsPolicy, self).__init__()
self._creds = credentials
def send(self, request, **kwargs):
session = request.context.session
try:
self._creds.signed_session(session)
except TypeError: # Credentials does not support session injection
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
request.context.session = session = self._creds.signed_session()
try:
try:
return self.next.send(request, **kwargs)
except (oauth2.rfc6749.errors.InvalidGrantError,
oauth2.rfc6749.errors.TokenExpiredError) as err:
error = "Token expired or is invalid. Attempting to refresh."
_LOGGER.warning(error)
try:
try:
self._creds.refresh_session(session)
except TypeError: # Credentials does not support session injection
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
request.context.session = session = self._creds.refresh_session()
return self.next.send(request, **kwargs)
except (oauth2.rfc6749.errors.InvalidGrantError,
oauth2.rfc6749.errors.TokenExpiredError) as err:
msg = "Token expired or is invalid."
raise_with_traceback(TokenExpiredError, msg, err)
except (requests.RequestException,
oauth2.rfc6749.errors.OAuth2Error) as err:
msg = "Error occurred in request."
raise_with_traceback(ClientRequestError, msg, err)
class RequestsPatchSession(HTTPPolicy):
"""Implements request level configuration
that are actually to be done at the session level.
This is highly deprecated, and is totally legacy.
The pipeline structure allows way better design for this.
"""
_protocols = ['http://', 'https://']
def send(self, request, **kwargs):
"""Patch the current session with Request level operation config.
This is deprecated, we shouldn't patch the session with
arguments at the Request, and "config" should be used.
"""
session = request.context.session
old_max_redirects = None
if 'max_redirects' in kwargs:
warnings.warn("max_redirects in operation kwargs is deprecated, use config.redirect_policy instead",
DeprecationWarning)
old_max_redirects = session.max_redirects
session.max_redirects = int(kwargs['max_redirects'])
old_trust_env = None
if 'use_env_proxies' in kwargs:
warnings.warn("use_env_proxies in operation kwargs is deprecated, use config.proxies instead",
DeprecationWarning)
old_trust_env = session.trust_env
session.trust_env = bool(kwargs['use_env_proxies'])
old_retries = {}
if 'retries' in kwargs:
warnings.warn("retries in operation kwargs is deprecated, use config.retry_policy instead",
DeprecationWarning)
max_retries = kwargs['retries']
for protocol in self._protocols:
old_retries[protocol] = session.adapters[protocol].max_retries
session.adapters[protocol].max_retries = max_retries
try:
return self.next.send(request, **kwargs)
finally:
if old_max_redirects:
session.max_redirects = old_max_redirects
if old_trust_env:
session.trust_env = old_trust_env
if old_retries:
for protocol in self._protocols:
session.adapters[protocol].max_retries = old_retries[protocol]
class RequestsContext(object):
def __init__(self, session):
self.session = session
class PipelineRequestsHTTPSender(HTTPSender):
"""Implements a basic Pipeline, that supports universal HTTP lib "requests" driver.
"""
def __init__(self, universal_http_requests_driver=None):
# type: (Optional[BasicRequestsHTTPSender]) -> None
self.driver = universal_http_requests_driver or BasicRequestsHTTPSender()
def __enter__(self):
# type: () -> PipelineRequestsHTTPSender
self.driver.__enter__()
return self
def __exit__(self, *exc_details): # pylint: disable=arguments-differ
self.driver.__exit__(*exc_details)
def close(self):
self.__exit__()
def build_context(self):
# type: () -> RequestsContext
return RequestsContext(
session=self.driver.session,
)
def send(self, request, **kwargs):
# type: (Request[ClientRequest], Any) -> Response
"""Send request object according to configuration.
:param Request request: The request object to be sent.
"""
if request.context is None: # Should not happen, but make mypy happy and does not hurt
request.context = self.build_context()
if request.context.session is not self.driver.session:
kwargs['session'] = request.context.session
return Response(
request,
self.driver.send(request.http_request, **kwargs)
)

Просмотреть файл

@ -0,0 +1,239 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
"""
This module represents universal policy that works whatever the HTTPSender implementation
"""
import json
import logging
import os
import xml.etree.ElementTree as ET
import platform
from typing import Mapping, Any, Optional, AnyStr, Union, IO, cast, TYPE_CHECKING # pylint: disable=unused-import
from ..version import msrest_version as _msrest_version
from . import SansIOHTTPPolicy
from ..exceptions import DeserializationError, raise_with_traceback
from ..http_logger import log_request, log_response
if TYPE_CHECKING:
from . import Request, Response # pylint: disable=unused-import
_LOGGER = logging.getLogger(__name__)
class HeadersPolicy(SansIOHTTPPolicy):
"""A simple policy that sends the given headers
with the request.
This overwrite any headers already defined in the request.
"""
def __init__(self, headers):
# type: (Mapping[str, str]) -> None
self.headers = headers
def on_request(self, request, **kwargs):
# type: (Request, Any) -> None
http_request = request.http_request
http_request.headers.update(self.headers)
class UserAgentPolicy(SansIOHTTPPolicy):
_USERAGENT = "User-Agent"
_ENV_ADDITIONAL_USER_AGENT = 'AZURE_HTTP_USER_AGENT'
def __init__(self, user_agent=None, overwrite=False):
# type: (Optional[str], bool) -> None
self._overwrite = overwrite
if user_agent is None:
self._user_agent = "python/{} ({}) msrest/{}".format(
platform.python_version(),
platform.platform(),
_msrest_version
)
else:
self._user_agent = user_agent
# Whatever you gave me a header explicitly or not,
# if the env variable is set, add to it.
add_user_agent_header = os.environ.get(self._ENV_ADDITIONAL_USER_AGENT, None)
if add_user_agent_header is not None:
self.add_user_agent(add_user_agent_header)
@property
def user_agent(self):
# type: () -> str
"""The current user agent value."""
return self._user_agent
def add_user_agent(self, value):
# type: (str) -> None
"""Add value to current user agent with a space.
:param str value: value to add to user agent.
"""
self._user_agent = "{} {}".format(self._user_agent, value)
def on_request(self, request, **kwargs):
# type: (Request, Any) -> None
http_request = request.http_request
if self._overwrite or self._USERAGENT not in http_request.headers:
http_request.headers[self._USERAGENT] = self._user_agent
class HTTPLogger(SansIOHTTPPolicy):
"""A policy that logs HTTP request and response to the DEBUG logger.
This accepts both global configuration, and kwargs request level with "enable_http_logger"
"""
def __init__(self, enable_http_logger = False):
self.enable_http_logger = enable_http_logger
def on_request(self, request, **kwargs):
# type: (Request, Any) -> None
http_request = request.http_request
if kwargs.get("enable_http_logger", self.enable_http_logger):
log_request(None, http_request)
def on_response(self, request, response, **kwargs):
# type: (Request, Response, Any) -> None
http_request = request.http_request
if kwargs.get("enable_http_logger", self.enable_http_logger):
log_response(None, http_request, response.http_response, result=response)
class RawDeserializer(SansIOHTTPPolicy):
JSON_MIMETYPES = [
'application/json',
'text/json' # Because we're open minded people...
]
# Name used in context
CONTEXT_NAME = "deserialized_data"
@classmethod
def deserialize_from_text(cls, data, content_type=None):
# type: (Optional[Union[AnyStr, IO]], Optional[str]) -> Any
"""Decode data according to content-type.
Accept a stream of data as well, but will be load at once in memory for now.
If no content-type, will return the string version (not bytes, not stream)
:param data: Input, could be bytes or stream (will be decoded with UTF8) or text
:type data: str or bytes or IO
:param str content_type: The content type.
"""
if hasattr(data, 'read'):
# Assume a stream
data = cast(IO, data).read()
if isinstance(data, bytes):
data_as_str = data.decode(encoding='utf-8-sig')
else:
# Explain to mypy the correct type.
data_as_str = cast(str, data)
if content_type is None:
return data
if content_type in cls.JSON_MIMETYPES:
try:
return json.loads(data_as_str)
except ValueError as err:
raise DeserializationError("JSON is invalid: {}".format(err), err)
elif "xml" in (content_type or []):
try:
return ET.fromstring(data_as_str)
except ET.ParseError:
# It might be because the server has an issue, and returned JSON with
# content-type XML....
# So let's try a JSON load, and if it's still broken
# let's flow the initial exception
def _json_attemp(data):
try:
return True, json.loads(data)
except ValueError:
return False, None # Don't care about this one
success, json_result = _json_attemp(data)
if success:
return json_result
# If i'm here, it's not JSON, it's not XML, let's scream
# and raise the last context in this block (the XML exception)
# The function hack is because Py2.7 messes up with exception
# context otherwise.
_LOGGER.critical("Wasn't XML not JSON, failing")
raise_with_traceback(DeserializationError, "XML is invalid")
raise DeserializationError("Cannot deserialize content-type: {}".format(content_type))
@classmethod
def deserialize_from_http_generics(cls, body_bytes, headers):
# type: (Optional[Union[AnyStr, IO]], Mapping) -> Any
"""Deserialize from HTTP response.
Use bytes and headers to NOT use any requests/aiohttp or whatever
specific implementation.
Headers will tested for "content-type"
"""
# Try to use content-type from headers if available
content_type = None
if 'content-type' in headers:
content_type = headers['content-type'].split(";")[0].strip().lower()
# Ouch, this server did not declare what it sent...
# Let's guess it's JSON...
# Also, since Autorest was considering that an empty body was a valid JSON,
# need that test as well....
else:
content_type = "application/json"
if body_bytes:
return cls.deserialize_from_text(body_bytes, content_type)
return None
def on_response(self, request, response, **kwargs):
# type: (Request, Response, Any) -> None
"""Extract data from the body of a REST response object.
This will load the entire payload in memory.
Will follow Content-Type to parse.
We assume everything is UTF8 (BOM acceptable).
:param raw_data: Data to be processed.
:param content_type: How to parse if raw_data is a string/bytes.
:raises JSONDecodeError: If JSON is requested and parsing is impossible.
:raises UnicodeDecodeError: If bytes is not UTF8
:raises xml.etree.ElementTree.ParseError: If bytes is not valid XML
"""
# If response was asked as stream, do NOT read anything and quit now
if kwargs.get("stream", True):
return
http_response = response.http_response
response.context[self.CONTEXT_NAME] = self.deserialize_from_http_generics(
http_response.text(),
http_response.headers
)

Просмотреть файл

@ -23,6 +23,12 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from .poller import LROPoller, NoPolling, PollingMethod
import sys
__all__ = ['LROPoller', 'NoPolling', 'PollingMethod']
from .poller import LROPoller, NoPolling, PollingMethod
__all__ = ['LROPoller', 'NoPolling', 'PollingMethod']
if sys.version_info >= (3, 5, 2):
# Not executed on old Python, no syntax error
from .async_poller import AsyncNoPolling, AsyncPollingMethod, async_poller
__all__ += ['AsyncNoPolling', 'AsyncPollingMethod', 'async_poller']

Просмотреть файл

@ -0,0 +1,88 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from .poller import NoPolling as _NoPolling
from ..serialization import Model
from ..service_client import ServiceClient
from ..pipeline import ClientRawResponse
class AsyncPollingMethod(object):
"""ABC class for polling method.
"""
def initialize(self, client, initial_response, deserialization_callback):
raise NotImplementedError("This method needs to be implemented")
async def run(self):
raise NotImplementedError("This method needs to be implemented")
def status(self):
raise NotImplementedError("This method needs to be implemented")
def finished(self):
raise NotImplementedError("This method needs to be implemented")
def resource(self):
raise NotImplementedError("This method needs to be implemented")
class AsyncNoPolling(_NoPolling):
"""An empty async poller that returns the deserialized initial response.
"""
async def run(self):
"""Empty run, no polling.
Just override initial run to add "async"
"""
pass
async def async_poller(client, initial_response, deserialization_callback, polling_method):
"""Async Poller for long running operations.
:param client: A msrest service client. Can be a SDK client and it will be casted to a ServiceClient.
:type client: msrest.service_client.ServiceClient
:param initial_response: The initial call response
:type initial_response: msrest.universal_http.ClientResponse or msrest.pipeline.ClientRawResponse
:param deserialization_callback: A callback that takes a Response and return a deserialized object. If a subclass of Model is given, this passes "deserialize" as callback.
:type deserialization_callback: callable or msrest.serialization.Model
:param polling_method: The polling strategy to adopt
:type polling_method: msrest.polling.PollingMethod
"""
try:
client = client if isinstance(client, ServiceClient) else client._client
except AttributeError:
raise ValueError("Poller client parameter must be a low-level msrest Service Client or a SDK client.")
response = initial_response.response if isinstance(initial_response, ClientRawResponse) else initial_response
if isinstance(deserialization_callback, type) and issubclass(deserialization_callback, Model):
deserialization_callback = deserialization_callback.deserialize
# Might raise a CloudError
polling_method.initialize(client, response, deserialization_callback)
await polling_method.run()
return polling_method.resource()

Просмотреть файл

@ -43,6 +43,8 @@ import isodate
from typing import Dict, Any
from .pipeline.universal import RawDeserializer
from .exceptions import (
ValidationError,
SerializationError,
@ -1335,77 +1337,47 @@ class Deserializer(object):
pass # Target is not a Model, no classify
return target, target.__class__.__name__
JSON_MIMETYPES = [
'application/json',
'text/json' # Because we're open minded people...
]
@staticmethod
def _unpack_content(raw_data, content_type=None):
"""Extract data from the body of a REST response object.
"""Extract the correct structure for deserialization.
If raw_data is a requests.Response object, follow Content-Type
to parse (ignore content_type parameter).
If bytes is given, decode using UTF8 first.
If content_type is given, try to parse.
Otherwise, return initial data.
We assume everything is UTF8 (BOM acceptable).
If raw_data is a PipelineResponse, try to extract the result of RawDeserializer.
if we can't, raise. Your Pipeline should have a RawDeserializer.
If not a pipeline response and raw_data is bytes or string, use content-type
to decode it. If no content-type, try JSON.
If raw_data is something else, bypass all logic and return it directly.
:param raw_data: Data to be processed.
:param content_type: How to parse if raw_data is a string/bytes.
:raises JSONDecodeError: If JSON is requested and parsing is impossible.
:raises UnicodeDecodeError: If bytes is not UTF8
"""
# Assume this is enough to detect a Pipeline Response without importing it
context = getattr(raw_data, "context", {})
if context:
if RawDeserializer.CONTEXT_NAME in context:
return context[RawDeserializer.CONTEXT_NAME]
raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize")
if hasattr(raw_data, 'text'): # Our requests.Response test
# Try to use content-type from headers if available
if 'content-type' in raw_data.headers:
content_type = raw_data.headers['content-type'].split(";")[0].strip().lower()
# Ouch, this server did not declare what it sent...
# Use Swagger "produces", which will be passed to "content_type" here
# If "content_type" also is empty, this means that it's an old version
# of Autorest for Python, let's guess it's JSON...
# Also, since Autorest was considering that an empty body was a valid JSON,
# need that test as well....
elif not content_type:
if not raw_data.text:
return None
content_type = "application/json"
# Whatever content type, data is readable (not bytes). Get it as a string.
data = raw_data.text
elif raw_data and isinstance(raw_data, bytes):
data = raw_data.decode(encoding='utf-8-sig')
else:
data = raw_data
#Assume this is enough to recognize universal_http.ClientResponse without importing it
if hasattr(raw_data, "body"):
return RawDeserializer.deserialize_from_http_generics(
raw_data.text(),
raw_data.headers
)
if content_type in Deserializer.JSON_MIMETYPES:
try:
return json.loads(data)
except ValueError as err:
raise DeserializationError("JSON is invalid: {}".format(err), err)
elif "xml" in (content_type or []):
try:
return ET.fromstring(data)
except ET.ParseError:
# It might be because the server has an issue, and returned JSON with
# content-type XML....
# So let's try a JSON load, and if it's still broken
# let's flow the initial exception
def _json_attemp(data):
try:
return True, json.loads(data)
except ValueError:
return False, None # Don't care about this one
success, data = _json_attemp(data)
if success:
return data
# If i'm here, it's not JSON, it's not XML, let's scream
# and raise the last context in this block (the XML exception)
# The function hack is because Py2.7 messes up with exception
# context otherwise.
_LOGGER.critical("Wasn't XML not JSON, failing")
raise
return data
# Assume this enough to recognize requests.Response without importing it.
if hasattr(raw_data, '_content_consumed'):
return RawDeserializer.deserialize_from_http_generics(
raw_data.text,
raw_data.headers
)
if isinstance(raw_data, (basestring, bytes)) or hasattr(raw_data, 'read'):
return RawDeserializer.deserialize_from_text(raw_data, content_type)
return raw_data
def _instantiate_model(self, response, attrs, additional_properties=None):
"""Instantiate a response model passing in deserialized args.

Просмотреть файл

@ -24,35 +24,54 @@
#
# --------------------------------------------------------------------------
import contextlib
import logging
import os
import sys
try:
from urlparse import urljoin, urlparse
except ImportError:
from urllib.parse import urljoin, urlparse
import warnings
from typing import Any, Dict, Union, IO, Tuple, Optional, cast, TYPE_CHECKING
if TYPE_CHECKING:
from .configuration import Configuration
from oauthlib import oauth2
import requests.adapters
from typing import List, Any, Dict, Union, IO, Tuple, Optional, Callable, Iterator, cast, TYPE_CHECKING # pylint: disable=unused-import
from .authentication import Authentication
from .pipeline import ClientRequest
from .http_logger import log_request, log_response
from .exceptions import (
TokenExpiredError,
ClientRequestError,
raise_with_traceback)
from .universal_http import ClientRequest, ClientResponse
from .universal_http.requests import (
RequestsHTTPSender,
)
from .pipeline import Request, Pipeline, HTTPPolicy, SansIOHTTPPolicy
from .pipeline.requests import (
PipelineRequestsHTTPSender,
RequestsCredentialsPolicy,
RequestsPatchSession
)
from .pipeline.universal import (
HTTPLogger,
RawDeserializer
)
if TYPE_CHECKING:
from .configuration import Configuration # pylint: disable=unused-import
from .universal_http.requests import RequestsClientResponse # pylint: disable=unused-import
import requests # pylint: disable=unused-import
if sys.version_info >= (3, 5, 2):
# Not executed on old Python, no syntax error
from .async_client import AsyncServiceClientMixin, AsyncSDKClientMixin # type: ignore
else:
class AsyncSDKClientMixin(object): # type: ignore
pass
class AsyncServiceClientMixin(object): # type: ignore
def __init__(self, creds, config):
pass
_LOGGER = logging.getLogger(__name__)
class SDKClient(object):
class SDKClient(AsyncSDKClientMixin):
"""The base class of all generated SDK client.
"""
def __init__(self, creds, config):
@ -73,121 +92,8 @@ class SDKClient(object):
def __exit__(self, *exc_details):
self._client.__exit__(*exc_details)
class _RequestsHTTPDriver(object):
_protocols = ['http://', 'https://']
def __init__(self, config):
# type: (Configuration) -> None
self.config = config
self.session = requests.Session()
def __enter__(self):
# type: () -> _RequestsHTTPDriver
return self
def __exit__(self, *exc_details):
self.close()
def close(self):
self.session.close()
def configure_session(self, **config):
# type: (str) -> Dict[str, Any]
"""Apply configuration to session.
:param config: Specific configuration overrides.
:rtype: dict
:return: A dict that will be kwarg-send to session.request
"""
kwargs = self.config.connection() # type: Dict[str, Any]
for opt in ['timeout', 'verify', 'cert']:
kwargs[opt] = config.get(opt, kwargs[opt])
kwargs.update({k:config[k] for k in ['cookies'] if k in config})
kwargs['allow_redirects'] = config.get(
'allow_redirects', bool(self.config.redirect_policy))
kwargs['headers'] = self.config.headers.copy()
kwargs['headers']['User-Agent'] = self.config.user_agent
proxies = config.get('proxies', self.config.proxies())
if proxies:
kwargs['proxies'] = proxies
kwargs['stream'] = config.get('stream', True)
self.session.max_redirects = int(config.get('max_redirects', self.config.redirect_policy()))
self.session.trust_env = bool(config.get('use_env_proxies', self.config.proxies.use_env_settings))
# Patch the redirect method directly *if not done already*
if not getattr(self.session.resolve_redirects, 'is_mrest_patched', False):
redirect_logic = self.session.resolve_redirects
def wrapped_redirect(resp, req, **kwargs):
attempt = self.config.redirect_policy.check_redirect(resp, req)
return redirect_logic(resp, req, **kwargs) if attempt else []
wrapped_redirect.is_mrest_patched = True # type: ignore
self.session.resolve_redirects = wrapped_redirect # type: ignore
# if "enable_http_logger" is defined at the operation level, take the value.
# if not, take the one in the client config
# if not, disable http_logger
hooks = []
if config.get("enable_http_logger", self.config.enable_http_logger):
def log_hook(r, *args, **kwargs):
log_request(None, r.request)
log_response(None, r.request, r, result=r)
hooks.append(log_hook)
def make_user_hook_cb(user_hook, session):
def user_hook_cb(r, *args, **kwargs):
kwargs.setdefault("msrest", {})['session'] = session
return user_hook(r, *args, **kwargs)
return user_hook_cb
for user_hook in self.config.hooks:
hooks.append(make_user_hook_cb(user_hook, self.session))
if hooks:
kwargs['hooks'] = {'response': hooks}
# Change max_retries in current all installed adapters
max_retries = config.get('retries', self.config.retry_policy())
for protocol in self._protocols:
self.session.adapters[protocol].max_retries=max_retries
output_kwargs = self.config.session_configuration_callback(
self.session,
self.config,
config,
**kwargs
)
if output_kwargs is not None:
kwargs = output_kwargs
return kwargs
def send(self, request, **config):
# type: (ClientRequest, Any) -> requests.Response
"""Send request object according to configuration.
:param ClientRequest request: The request object to be sent.
:param config: Any specific config overrides
"""
kwargs = config.copy()
if request.files:
kwargs['files'] = request.files
elif request.data:
kwargs['data'] = request.data
kwargs.setdefault("headers", {}).update(request.headers)
response = self.session.request(
request.method,
request.url,
**kwargs)
return response
class ServiceClient(object):
class ServiceClient(AsyncServiceClientMixin):
"""REST Service Client.
Maintains client pipeline and handles all requests and responses.
@ -197,47 +103,56 @@ class ServiceClient(object):
def __init__(self, creds, config):
# type: (Any, Configuration) -> None
if config is None:
raise ValueError("Config is a required parameter")
self.config = config
self.creds = creds if creds else Authentication()
self._http_driver = _RequestsHTTPDriver(config)
self._creds = creds
# Call the mixin AFTER self.config and self._creds
super(ServiceClient, self).__init__(creds, config)
# "pipeline" be should accessible from "config"
# In legacy mode this is weird, this config is a parameter of "pipeline"
# Should be revamp one day.
self.config.pipeline = self._create_default_pipeline()
def _create_default_pipeline(self):
# type: () -> Pipeline[ClientRequest, RequestsClientResponse]
policies = [
self.config.user_agent_policy, # UserAgent policy
RequestsPatchSession(), # Support deprecated operation config at the session level
self.config.http_logger_policy # HTTP request/response log
] # type: List[Union[HTTPPolicy, SansIOHTTPPolicy]]
if self._creds:
if isinstance(self._creds, (HTTPPolicy, SansIOHTTPPolicy)):
policies.insert(1, self._creds)
else:
# Assume this is the old credentials class, and then requests. Wrap it.
policies.insert(1, RequestsCredentialsPolicy(self._creds)) # Set credentials for requests based session
return Pipeline(
policies,
PipelineRequestsHTTPSender(RequestsHTTPSender(self.config)) # Send HTTP request using requests
)
def __enter__(self):
# type: () -> ServiceClient
self.config.keep_alive = True
self._http_driver.__enter__()
self.config.pipeline.__enter__()
return self
def __exit__(self, *exc_details):
self._http_driver.__exit__(*exc_details)
self.config.pipeline.__exit__(*exc_details)
self.config.keep_alive = False
def close(self):
# type: () -> None
"""Close the session if keep_alive is True.
"""Close the pipeline if keep_alive is True.
"""
self._http_driver.close()
self.config.pipeline.__exit__()
def _format_data(self, data):
# type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]
"""Format field data according to whether it is a stream or
a string for a form-data request.
:param data: The request field data.
:type data: str or file-like object.
"""
if hasattr(data, 'read'):
data = cast(IO, data)
data_name = None
try:
if data.name[0] != '<' and data.name[-1] != '>':
data_name = os.path.basename(data.name)
except (AttributeError, TypeError):
pass
return (data_name, data, "application/octet-stream")
return (None, cast(str, data))
def _request(self, url, params, headers, content, form_content):
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
def _request(self, method, url, params, headers, content, form_content):
# type: (str, str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
"""Create ClientRequest object.
:param str url: URL for the request.
@ -245,10 +160,7 @@ class ServiceClient(object):
:param dict headers: Headers
:param dict form_content: Form content
"""
request = ClientRequest()
if url:
request.url = self.format_url(url)
request = ClientRequest(method, self.format_url(url))
if params:
request.format_parameters(params)
@ -266,31 +178,10 @@ class ServiceClient(object):
request.add_content(content)
if form_content:
self._add_formdata(request, form_content)
request.add_formdata(form_content)
return request
def _add_formdata(self, request, content=None):
# type: (ClientRequest, Optional[Dict[str, str]]) -> None
"""Add data as a multipart form-data request to the request.
We only deal with file-like objects or strings at this point.
The requests is not yet streamed.
:param ClientRequest request: The request object to be sent.
:param dict headers: Any headers to add to the request.
:param dict content: Dictionary of the fields of the formdata.
"""
if content is None:
content = {}
content_type = request.headers.pop('Content-Type', None) if request.headers else None
if content_type and content_type.lower() == 'application/x-www-form-urlencoded':
# Do NOT use "add_content" that assumes input is JSON
request.data = {f: d for f, d in content.items() if d is not None}
else: # Assume "multipart/form-data"
request.files = {f: self._format_data(d) for f, d in content.items() if d is not None}
def send_formdata(self, request, headers=None, content=None, **config):
"""Send data as a multipart form-data request.
We only deal with file-like objects or strings at this point.
@ -304,10 +195,10 @@ class ServiceClient(object):
:param config: Any specific config overrides.
"""
request.headers = headers
self._add_formdata(request, content)
request.add_formdata(content)
return self.send(request, **config)
def send(self, request, headers=None, content=None, **config):
def send(self, request, headers=None, content=None, **kwargs):
"""Prepare and send request object according to configuration.
:param ClientRequest request: The request object to be sent.
@ -315,92 +206,54 @@ class ServiceClient(object):
:param content: Any body data to add to the request.
:param config: Any specific config overrides
"""
if self.config.keep_alive:
http_driver = self._http_driver
else:
http_driver = _RequestsHTTPDriver(self.config)
try:
self.creds.signed_session(http_driver.session)
except TypeError: # Credentials does not support session injection
http_driver.session = self.creds.signed_session()
if http_driver is self._http_driver:
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
kwargs = http_driver.configure_session(**config)
# "content" and "headers" are deprecated, only old SDK
if headers:
request.headers.update(headers)
if not request.files and request.data == [] and content is not None:
if not request.files and request.data is None and content is not None:
request.add_content(content)
# End of deprecation
response = None
kwargs.setdefault('stream', True)
try:
try:
response = http_driver.send(request, **kwargs)
return response
except (oauth2.rfc6749.errors.InvalidGrantError,
oauth2.rfc6749.errors.TokenExpiredError) as err:
error = "Token expired or is invalid. Attempting to refresh."
_LOGGER.warning(error)
try:
try:
self.creds.refresh_session(http_driver.session)
except TypeError: # Credentials does not support session injection
http_driver.session = self.creds.refresh_session()
if http_driver is self._http_driver:
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
# Only reconfigure on refresh if it's a new session
kwargs = http_driver.configure_session(**config)
response = http_driver.send(request, **kwargs)
return response
except (oauth2.rfc6749.errors.InvalidGrantError,
oauth2.rfc6749.errors.TokenExpiredError) as err:
msg = "Token expired or is invalid."
raise_with_traceback(TokenExpiredError, msg, err)
except (requests.RequestException,
oauth2.rfc6749.errors.OAuth2Error) as err:
msg = "Error occurred in request."
raise_with_traceback(ClientRequestError, msg, err)
pipeline_response = self.config.pipeline.run(request, **kwargs)
# There is too much thing that expects this method to return a "requests.Response"
# to break it in a compatible release.
# Also, to be pragmatic in the "sync" world "requests" rules anyway.
# However, attach the Universal HTTP response
# to get the streaming generator.
response = pipeline_response.http_response.internal_response
response._universal_http_response = pipeline_response.http_response
response.context = pipeline_response.context
return response
finally:
self._close_local_session_if_necessary(response, http_driver, kwargs['stream'])
self._close_local_session_if_necessary(response, kwargs['stream'])
def _close_local_session_if_necessary(self, response, http_driver, stream):
# Do NOT close session if using my own HTTP driver. No exception.
if self._http_driver is http_driver:
return
def _close_local_session_if_necessary(self, response, stream):
# Here, it's a local session, I might close it.
if not response or not stream:
http_driver.session.close()
if not self.config.keep_alive and (not response or not stream):
self.config.pipeline._sender.driver.session.close()
def stream_download(self, data, callback):
# type: (Union[requests.Response, ClientResponse], Callable) -> Iterator[bytes]
"""Generator for streaming request body data.
:param data: A response object to be streamed.
:param callback: Custom callback for monitoring progress.
"""
block = self.config.connection.data_block_size
if not data._content_consumed:
with contextlib.closing(data) as response:
for chunk in response.iter_content(block):
if not chunk:
break
if callback and callable(callback):
callback(chunk, response=response)
yield chunk
else:
for chunk in data.iter_content(block):
if not chunk:
break
if callback and callable(callback):
callback(chunk, response=data)
yield chunk
data.close()
try:
# Assume this is ClientResponse, which it should be if backward compat was not important
return cast(ClientResponse, data).stream_download(block, callback)
except AttributeError:
try:
# Assume this is the patched requests.Response from "send"
return data._universal_http_response.stream_download(block, callback) # type: ignore
except AttributeError:
# Assume this is a raw requests.Response
from .universal_http.requests import RequestsClientResponse
response = RequestsClientResponse(None, data)
return response.stream_download(block, callback)
def stream_upload(self, data, callback):
"""Generator for streaming request body data.
@ -446,8 +299,8 @@ class ServiceClient(object):
DeprecationWarning)
self.config.headers[header] = value
def get(self, url=None, params=None, headers=None, content=None, form_content=None):
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
def get(self, url, params=None, headers=None, content=None, form_content=None):
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
"""Create a GET request object.
:param str url: The request URL.
@ -455,12 +308,12 @@ class ServiceClient(object):
:param dict headers: Headers
:param dict form_content: Form content
"""
request = self._request(url, params, headers, content, form_content)
request = self._request('GET', url, params, headers, content, form_content)
request.method = 'GET'
return request
def put(self, url=None, params=None, headers=None, content=None, form_content=None):
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
def put(self, url, params=None, headers=None, content=None, form_content=None):
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
"""Create a PUT request object.
:param str url: The request URL.
@ -468,12 +321,11 @@ class ServiceClient(object):
:param dict headers: Headers
:param dict form_content: Form content
"""
request = self._request(url, params, headers, content, form_content)
request.method = 'PUT'
request = self._request('PUT', url, params, headers, content, form_content)
return request
def post(self, url=None, params=None, headers=None, content=None, form_content=None):
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
def post(self, url, params=None, headers=None, content=None, form_content=None):
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
"""Create a POST request object.
:param str url: The request URL.
@ -481,12 +333,11 @@ class ServiceClient(object):
:param dict headers: Headers
:param dict form_content: Form content
"""
request = self._request(url, params, headers, content, form_content)
request.method = 'POST'
request = self._request('POST', url, params, headers, content, form_content)
return request
def head(self, url=None, params=None, headers=None, content=None, form_content=None):
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
def head(self, url, params=None, headers=None, content=None, form_content=None):
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
"""Create a HEAD request object.
:param str url: The request URL.
@ -494,12 +345,11 @@ class ServiceClient(object):
:param dict headers: Headers
:param dict form_content: Form content
"""
request = self._request(url, params, headers, content, form_content)
request.method = 'HEAD'
request = self._request('HEAD', url, params, headers, content, form_content)
return request
def patch(self, url=None, params=None, headers=None, content=None, form_content=None):
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
def patch(self, url, params=None, headers=None, content=None, form_content=None):
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
"""Create a PATCH request object.
:param str url: The request URL.
@ -507,12 +357,11 @@ class ServiceClient(object):
:param dict headers: Headers
:param dict form_content: Form content
"""
request = self._request(url, params, headers, content, form_content)
request.method = 'PATCH'
request = self._request('PATCH', url, params, headers, content, form_content)
return request
def delete(self, url=None, params=None, headers=None, content=None, form_content=None):
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
def delete(self, url, params=None, headers=None, content=None, form_content=None):
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
"""Create a DELETE request object.
:param str url: The request URL.
@ -520,12 +369,11 @@ class ServiceClient(object):
:param dict headers: Headers
:param dict form_content: Form content
"""
request = self._request(url, params, headers, content, form_content)
request.method = 'DELETE'
request = self._request('DELETE', url, params, headers, content, form_content)
return request
def merge(self, url=None, params=None, headers=None, content=None, form_content=None):
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
def merge(self, url, params=None, headers=None, content=None, form_content=None):
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
"""Create a MERGE request object.
:param str url: The request URL.
@ -533,6 +381,5 @@ class ServiceClient(object):
:param dict headers: Headers
:param dict form_content: Form content
"""
request = self._request(url, params, headers, content, form_content)
request.method = 'MERGE'
request = self._request('MERGE', url, params, headers, content, form_content)
return request

Просмотреть файл

@ -0,0 +1,458 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from __future__ import absolute_import # we have a "requests" module that conflicts with "requests" on Py2.7
import abc
try:
import configparser
from configparser import NoOptionError
except ImportError:
import ConfigParser as configparser # type: ignore
from ConfigParser import NoOptionError # type: ignore
import json
import logging
import os.path
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
import xml.etree.ElementTree as ET
from typing import TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, Optional, Tuple, Callable, Iterator # pylint: disable=unused-import
HTTPResponseType = TypeVar("HTTPResponseType", bound='HTTPClientResponse')
# This file is NOT using any "requests" HTTP implementation
# However, the CaseInsensitiveDict is handy.
# If one day we reach the point where "requests" can be skip totally,
# might provide our own implementation
from requests.structures import CaseInsensitiveDict
from ..exceptions import ClientRequestError, raise_with_traceback
if TYPE_CHECKING:
from ..serialization import Model # pylint: disable=unused-import
_LOGGER = logging.getLogger(__name__)
try:
ABC = abc.ABC
except AttributeError: # Python 2.7, abc exists, but not ABC
ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) # type: ignore
try:
from contextlib import AbstractContextManager # type: ignore
except ImportError: # Python <= 3.5
class AbstractContextManager(object): # type: ignore
def __enter__(self):
"""Return `self` upon entering the runtime context."""
return self
@abc.abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
class HTTPSender(AbstractContextManager, ABC):
"""An http sender ABC.
"""
@abc.abstractmethod
def send(self, request, **config):
# type: (ClientRequest, Any) -> ClientResponse
"""Send the request using this HTTP sender.
"""
pass
class HTTPSenderConfiguration(object):
"""HTTP sender configuration.
This is composed of generic HTTP configuration, and could be use as a common
HTTP configuration format.
:param str filepath: Path to existing config file (optional).
"""
def __init__(self, filepath=None):
# Communication configuration
self.connection = ClientConnection()
# Headers (sent with every requests)
self.headers = {} # type: Dict[str, str]
# ProxyConfiguration
self.proxies = ClientProxies()
# Redirect configuration
self.redirect_policy = ClientRedirectPolicy()
self._config = configparser.ConfigParser()
self._config.optionxform = str # type: ignore
if filepath:
self.load(filepath)
def _clear_config(self):
# type: () -> None
"""Clearout config object in memory."""
for section in self._config.sections():
self._config.remove_section(section)
def save(self, filepath):
# type: (str) -> None
"""Save current configuration to file.
:param str filepath: Path to file where settings will be saved.
:raises: ValueError if supplied filepath cannot be written to.
"""
sections = [
"Connection",
"Proxies",
"RedirectPolicy"]
for section in sections:
self._config.add_section(section)
self._config.set("Connection", "timeout", self.connection.timeout)
self._config.set("Connection", "verify", self.connection.verify)
self._config.set("Connection", "cert", self.connection.cert)
self._config.set("Proxies", "proxies", self.proxies.proxies)
self._config.set("Proxies", "env_settings",
self.proxies.use_env_settings)
self._config.set("RedirectPolicy", "allow", self.redirect_policy.allow)
self._config.set("RedirectPolicy", "max_redirects",
self.redirect_policy.max_redirects)
try:
with open(filepath, 'w') as configfile:
self._config.write(configfile)
except (KeyError, EnvironmentError):
error = "Supplied config filepath invalid."
raise_with_traceback(ValueError, error)
finally:
self._clear_config()
def load(self, filepath):
# type: (str) -> None
"""Load configuration from existing file.
:param str filepath: Path to existing config file.
:raises: ValueError if supplied config file is invalid.
"""
try:
self._config.read(filepath)
import ast
self.connection.timeout = \
self._config.getint("Connection", "timeout")
self.connection.verify = \
self._config.getboolean("Connection", "verify")
self.connection.cert = \
self._config.get("Connection", "cert")
self.proxies.proxies = \
ast.literal_eval(self._config.get("Proxies", "proxies"))
self.proxies.use_env_settings = \
self._config.getboolean("Proxies", "env_settings")
self.redirect_policy.allow = \
self._config.getboolean("RedirectPolicy", "allow")
self.redirect_policy.max_redirects = \
self._config.getint("RedirectPolicy", "max_redirects")
except (ValueError, EnvironmentError, NoOptionError):
error = "Supplied config file incompatible."
raise_with_traceback(ValueError, error)
finally:
self._clear_config()
class ClientRequest(object):
"""Represents a HTTP request.
URL can be given without query parameters, to be added later using "format_parameters".
Instance can be created without data, to be added later using "add_content"
Instance can be created without files, to be added later using "add_formdata"
:param str method: HTTP method (GET, HEAD, etc.)
:param str url: At least complete scheme/host/path
:param dict[str,str] headers: HTTP headers
:param files: Files list.
:param data: Body to be sent.
:type data: bytes or str.
"""
def __init__(self, method, url, headers=None, files=None, data=None):
# type: (str, str, Mapping[str, str], Any, Any) -> None
self.method = method
self.url = url
self.headers = CaseInsensitiveDict(headers)
self.files = files
self.data = data
def __repr__(self):
return '<ClientRequest [%s]>' % (self.method)
@property
def body(self):
"""Alias to data."""
return self.data
@body.setter
def body(self, value):
self.data = value
def format_parameters(self, params):
# type: (Dict[str, str]) -> None
"""Format parameters into a valid query string.
It's assumed all parameters have already been quoted as
valid URL strings.
:param dict params: A dictionary of parameters.
"""
query = urlparse(self.url).query
if query:
self.url = self.url.partition('?')[0]
existing_params = {
p[0]: p[-1]
for p in [p.partition('=') for p in query.split('&')]
}
params.update(existing_params)
query_params = ["{}={}".format(k, v) for k, v in params.items()]
query = '?' + '&'.join(query_params)
self.url = self.url + query
def add_content(self, data):
# type: (Optional[Union[Dict[str, Any], ET.Element]]) -> None
"""Add a body to the request.
:param data: Request body data, can be a json serializable
object (e.g. dictionary) or a generator (e.g. file data).
"""
if data is None:
return
if isinstance(data, ET.Element):
bytes_data = ET.tostring(data, encoding="utf8")
self.headers['Content-Length'] = str(len(bytes_data))
self.data = bytes_data
return
# By default, assume JSON
try:
self.data = json.dumps(data)
self.headers['Content-Length'] = str(len(self.data))
except TypeError:
self.data = data
@staticmethod
def _format_data(data):
# type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]
"""Format field data according to whether it is a stream or
a string for a form-data request.
:param data: The request field data.
:type data: str or file-like object.
"""
if hasattr(data, 'read'):
data = cast(IO, data)
data_name = None
try:
if data.name[0] != '<' and data.name[-1] != '>':
data_name = os.path.basename(data.name)
except (AttributeError, TypeError):
pass
return (data_name, data, "application/octet-stream")
return (None, cast(str, data))
def add_formdata(self, content=None):
# type: (Optional[Dict[str, str]]) -> None
"""Add data as a multipart form-data request to the request.
We only deal with file-like objects or strings at this point.
The requests is not yet streamed.
:param dict headers: Any headers to add to the request.
:param dict content: Dictionary of the fields of the formdata.
"""
if content is None:
content = {}
content_type = self.headers.pop('Content-Type', None) if self.headers else None
if content_type and content_type.lower() == 'application/x-www-form-urlencoded':
# Do NOT use "add_content" that assumes input is JSON
self.data = {f: d for f, d in content.items() if d is not None}
else: # Assume "multipart/form-data"
self.files = {f: self._format_data(d) for f, d in content.items() if d is not None}
class HTTPClientResponse(object):
"""Represent a HTTP response.
No body is defined here on purpose, since async pipeline
will provide async ways to access the body
You have two differents types of body:
- Full in-memory using "body" as bytes
"""
def __init__(self, request, internal_response):
# type: (ClientRequest, Any) -> None
self.request = request
self.internal_response = internal_response
self.status_code = None # type: Optional[int]
self.headers = {} # type: Dict[str, str]
self.reason = None # type: Optional[str]
def body(self):
# type: () -> bytes
"""Return the whole body as bytes in memory.
"""
pass
def text(self, encoding=None):
# type: (str) -> str
"""Return the whole body as a string.
:param str encoding: The encoding to apply. If None, use "utf-8".
Implementation can be smarter if they want (using headers).
"""
return self.body().decode(encoding or "utf-8")
def raise_for_status(self):
"""Raise for status. Should be overriden, but basic implementation provided.
"""
if self.status_code >= 400:
raise ClientRequestError("Received status code {}".format(self.status_code))
class ClientResponse(HTTPClientResponse):
def stream_download(self, chunk_size=None, callback=None):
# type: (Optional[int], Optional[Callable]) -> Iterator[bytes]
"""Generator for streaming request body data.
Should be implemented by sub-classes if streaming download
is supported.
:param callback: Custom callback for monitoring progress.
:param int chunk_size:
"""
pass
class ClientRedirectPolicy(object):
"""Redirect configuration settings.
"""
def __init__(self):
self.allow = True
self.max_redirects = 30
def __bool__(self):
# type: () -> bool
"""Whether redirects are allowed."""
return self.allow
def __call__(self):
# type: () -> int
"""Return configuration to be applied to connection."""
debug = "Configuring redirects: allow=%r, max=%r"
_LOGGER.debug(debug, self.allow, self.max_redirects)
return self.max_redirects
class ClientProxies(object):
"""Proxy configuration settings.
Proxies can also be configured using HTTP_PROXY and HTTPS_PROXY
environment variables, in which case set use_env_settings to True.
"""
def __init__(self):
self.proxies = {}
self.use_env_settings = True
def __call__(self):
# type: () -> Dict[str, str]
"""Return configuration to be applied to connection."""
proxy_string = "\n".join(
[" {}: {}".format(k, v) for k, v in self.proxies.items()])
_LOGGER.debug("Configuring proxies: %r", proxy_string)
debug = "Evaluate proxies against ENV settings: %r"
_LOGGER.debug(debug, self.use_env_settings)
return self.proxies
def add(self, protocol, proxy_url):
# type: (str, str) -> None
"""Add proxy.
:param str protocol: Protocol for which proxy is to be applied. Can
be 'http', 'https', etc. Can also include host.
:param str proxy_url: The proxy URL. Where basic auth is required,
use the format: http://user:password@host
"""
self.proxies[protocol] = proxy_url
class ClientConnection(object):
"""Request connection configuration settings.
"""
def __init__(self):
self.timeout = 100
self.verify = True
self.cert = None
self.data_block_size = 4096
def __call__(self):
# type: () -> Dict[str, Union[str, int]]
"""Return configuration to be applied to connection."""
debug = "Configuring request: timeout=%r, verify=%r, cert=%r"
_LOGGER.debug(debug, self.timeout, self.verify, self.cert)
return {'timeout': self.timeout,
'verify': self.verify,
'cert': self.cert}
__all__ = [
'ClientRequest',
'ClientResponse',
'HTTPSender',
# Generic HTTP configuration
'HTTPSenderConfiguration',
'ClientRedirectPolicy',
'ClientProxies',
'ClientConnection'
]
try:
from .async_abc import AsyncHTTPSender, AsyncClientResponse # pylint: disable=unused-import
from .async_abc import __all__ as _async_all
__all__ += _async_all
except SyntaxError: # Python 2
pass

Просмотреть файл

@ -0,0 +1,100 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import Any, Callable, AsyncIterator, Optional
import aiohttp
from . import AsyncHTTPSender, ClientRequest, AsyncClientResponse
# Matching requests, because why not?
CONTENT_CHUNK_SIZE = 10 * 1024
class AioHTTPSender(AsyncHTTPSender):
"""AioHttp HTTP sender implementation.
"""
def __init__(self, *, loop=None):
self._session = aiohttp.ClientSession(loop=loop)
async def __aenter__(self):
await self._session.__aenter__()
return self
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
await self._session.__aexit__(*exc_details)
async def send(self, request: ClientRequest, **config: Any) -> AsyncClientResponse:
"""Send the request using this HTTP sender.
Will pre-load the body into memory to be available with a sync method.
pass stream=True to avoid this behavior.
"""
result = await self._session.request(
request.method,
request.url,
**config
)
response = AioHttpClientResponse(request, result)
if not config.get("stream", False):
await response.load_body()
return response
class AioHttpClientResponse(AsyncClientResponse):
def __init__(self, request: ClientRequest, aiohttp_response: aiohttp.ClientResponse) -> None:
super(AioHttpClientResponse, self).__init__(request, aiohttp_response)
# https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
self.status_code = aiohttp_response.status
self.headers = aiohttp_response.headers
self.reason = aiohttp_response.reason
self._body = None
def body(self) -> bytes:
"""Return the whole body as bytes in memory.
"""
if not self._body:
raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.")
return self._body
async def load_body(self) -> None:
"""Load in memory the body, so it could be accessible from sync methods."""
self._body = await self.internal_response.read()
def raise_for_status(self):
self.internal_response.raise_for_status()
def stream_download(self, chunk_size: Optional[int] = None, callback: Optional[Callable] = None) -> AsyncIterator[bytes]:
"""Generator for streaming request body data.
"""
chunk_size = chunk_size or CONTENT_CHUNK_SIZE
async def async_gen(resp):
while True:
chunk = await resp.content.read(chunk_size)
if not chunk:
break
callback(chunk, resp)
return async_gen(self.internal_response)

Просмотреть файл

@ -0,0 +1,90 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import abc
from typing import Any, List, Union, Callable, AsyncIterator, Optional
try:
from contextlib import AbstractAsyncContextManager # type: ignore
except ImportError: # Python <= 3.7
class AbstractAsyncContextManager(object): # type: ignore
async def __aenter__(self):
"""Return `self` upon entering the runtime context."""
return self
@abc.abstractmethod
async def __aexit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
from . import ClientRequest, HTTPClientResponse
class AsyncClientResponse(HTTPClientResponse):
def stream_download(self, chunk_size: Optional[int] = None, callback: Optional[Callable] = None) -> AsyncIterator[bytes]:
"""Generator for streaming request body data.
Should be implemented by sub-classes if streaming download
is supported.
:param callback: Custom callback for monitoring progress.
:param int chunk_size:
"""
pass
class AsyncHTTPSender(AbstractAsyncContextManager, abc.ABC):
"""An http sender ABC.
"""
@abc.abstractmethod
async def send(self, request: ClientRequest, **config: Any) -> AsyncClientResponse:
"""Send the request using this HTTP sender.
"""
pass
def build_context(self) -> Any:
"""Allow the sender to build a context that will be passed
across the pipeline with the request.
Return type has no constraints. Implementation is not
required and None by default.
"""
return None
def __enter__(self):
raise TypeError("Use 'async with' instead")
def __exit__(self, exc_type, exc_val, exc_tb):
# __exit__ should exist in pair with __enter__ but never executed
pass # pragma: no cover
__all__ = [
'AsyncHTTPSender',
'AsyncClientResponse'
]

Просмотреть файл

@ -0,0 +1,240 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import asyncio
from collections.abc import AsyncIterator
import functools
import logging
from typing import Any, Callable, Optional, AsyncIterator as AsyncIteratorType
from oauthlib import oauth2
import requests
from requests.models import CONTENT_CHUNK_SIZE
from ..exceptions import (
TokenExpiredError,
ClientRequestError,
raise_with_traceback)
from . import AsyncHTTPSender, ClientRequest, AsyncClientResponse
from .requests import (
BasicRequestsHTTPSender,
RequestsHTTPSender,
HTTPRequestsClientResponse
)
_LOGGER = logging.getLogger(__name__)
class AsyncBasicRequestsHTTPSender(BasicRequestsHTTPSender, AsyncHTTPSender): # type: ignore
async def __aenter__(self):
return super(AsyncBasicRequestsHTTPSender, self).__enter__()
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
return super(AsyncBasicRequestsHTTPSender, self).__exit__()
async def send(self, request: ClientRequest, **kwargs: Any) -> AsyncClientResponse: # type: ignore
"""Send the request using this HTTP sender.
"""
# It's not recommended to provide its own session, and is mostly
# to enable some legacy code to plug correctly
session = kwargs.pop('session', self.session)
loop = kwargs.get("loop", asyncio.get_event_loop())
future = loop.run_in_executor(
None,
functools.partial(
session.request,
request.method,
request.url,
**kwargs
)
)
try:
return AsyncRequestsClientResponse(
request,
await future
)
except requests.RequestException as err:
msg = "Error occurred in request."
raise_with_traceback(ClientRequestError, msg, err)
class AsyncRequestsHTTPSender(AsyncBasicRequestsHTTPSender, RequestsHTTPSender): # type: ignore
async def send(self, request: ClientRequest, **kwargs: Any) -> AsyncClientResponse: # type: ignore
"""Send the request using this HTTP sender.
"""
requests_kwargs = self._configure_send(request, **kwargs)
return await super(AsyncRequestsHTTPSender, self).send(request, **requests_kwargs)
class _MsrestStopIteration(Exception):
pass
def _msrest_next(iterator):
""""To avoid:
TypeError: StopIteration interacts badly with generators and cannot be raised into a Future
"""
try:
return next(iterator)
except StopIteration:
raise _MsrestStopIteration()
class StreamDownloadGenerator(AsyncIterator):
def __init__(self, response: requests.Response, user_callback: Optional[Callable] = None, block: Optional[int] = None) -> None:
self.response = response
self.block = block or CONTENT_CHUNK_SIZE
self.user_callback = user_callback
self.iter_content_func = self.response.iter_content(self.block)
async def __anext__(self):
loop = asyncio.get_event_loop()
try:
chunk = await loop.run_in_executor(
None,
_msrest_next,
self.iter_content_func,
)
if not chunk:
raise _MsrestStopIteration()
if self.user_callback and callable(self.user_callback):
self.user_callback(chunk, self.response)
return chunk
except _MsrestStopIteration:
self.response.close()
raise StopAsyncIteration()
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
self.response.close()
raise
class AsyncRequestsClientResponse(AsyncClientResponse, HTTPRequestsClientResponse):
def stream_download(self, chunk_size: Optional[int] = None, callback: Optional[Callable] = None) -> AsyncIteratorType[bytes]:
"""Generator for streaming request body data.
:param callback: Custom callback for monitoring progress.
:param int chunk_size:
"""
return StreamDownloadGenerator(
self.internal_response,
callback,
chunk_size
)
# Trio support
try:
import trio
class TrioStreamDownloadGenerator(AsyncIterator):
def __init__(self, response: requests.Response, user_callback: Optional[Callable] = None, block: Optional[int] = None) -> None:
self.response = response
self.block = block or CONTENT_CHUNK_SIZE
self.user_callback = user_callback
self.iter_content_func = self.response.iter_content(self.block)
async def __anext__(self):
try:
chunk = await trio.run_sync_in_worker_thread(
_msrest_next,
self.iter_content_func,
)
if not chunk:
raise _MsrestStopIteration()
if self.user_callback and callable(self.user_callback):
self.user_callback(chunk, self.response)
return chunk
except _MsrestStopIteration:
self.response.close()
raise StopAsyncIteration()
except Exception as err:
_LOGGER.warning("Unable to stream download: %s", err)
self.response.close()
raise
class TrioAsyncRequestsClientResponse(AsyncClientResponse, HTTPRequestsClientResponse):
def stream_download(self, chunk_size: Optional[int] = None, callback: Optional[Callable] = None) -> AsyncIteratorType[bytes]:
"""Generator for streaming request body data.
:param callback: Custom callback for monitoring progress.
:param int chunk_size:
"""
return TrioStreamDownloadGenerator(
self.internal_response,
callback,
chunk_size
)
class AsyncTrioBasicRequestsHTTPSender(BasicRequestsHTTPSender, AsyncHTTPSender): # type: ignore
async def __aenter__(self):
return super(AsyncTrioBasicRequestsHTTPSender, self).__enter__()
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
return super(AsyncTrioBasicRequestsHTTPSender, self).__exit__()
async def send(self, request: ClientRequest, **kwargs: Any) -> AsyncClientResponse: # type: ignore
"""Send the request using this HTTP sender.
"""
# It's not recommended to provide its own session, and is mostly
# to enable some legacy code to plug correctly
session = kwargs.pop('session', self.session)
trio_limiter = kwargs.get("trio_limiter", None)
future = trio.run_sync_in_worker_thread(
functools.partial(
session.request,
request.method,
request.url,
**kwargs
),
limiter=trio_limiter
)
try:
return TrioAsyncRequestsClientResponse(
request,
await future
)
except requests.RequestException as err:
msg = "Error occurred in request."
raise_with_traceback(ClientRequestError, msg, err)
class AsyncTrioRequestsHTTPSender(AsyncTrioBasicRequestsHTTPSender, RequestsHTTPSender): # type: ignore
async def send(self, request: ClientRequest, **kwargs: Any) -> AsyncClientResponse: # type: ignore
"""Send the request using this HTTP sender.
"""
requests_kwargs = self._configure_send(request, **kwargs)
return await super(AsyncTrioRequestsHTTPSender, self).send(request, **requests_kwargs)
except ImportError:
# trio not installed
pass

Просмотреть файл

@ -0,0 +1,459 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
"""
This module is the requests implementation of Pipeline ABC
"""
from __future__ import absolute_import # we have a "requests" module that conflicts with "requests" on Py2.7
import contextlib
import logging
import threading
from typing import TYPE_CHECKING, List, Callable, Iterator, Any, Union, Dict, Optional # pylint: disable=unused-import
import warnings
from oauthlib import oauth2
import requests
from requests.models import CONTENT_CHUNK_SIZE
from urllib3 import Retry # Needs requests 2.16 at least to be safe
from ..exceptions import (
TokenExpiredError,
ClientRequestError,
raise_with_traceback)
from . import HTTPSender, HTTPClientResponse, ClientResponse, HTTPSenderConfiguration
if TYPE_CHECKING:
from . import ClientRequest # pylint: disable=unused-import
_LOGGER = logging.getLogger(__name__)
class HTTPRequestsClientResponse(HTTPClientResponse):
def __init__(self, request, requests_response):
super(HTTPRequestsClientResponse, self).__init__(request, requests_response)
self.status_code = requests_response.status_code
self.headers = requests_response.headers
self.reason = requests_response.reason
def body(self):
return self.internal_response.content
def text(self, encoding=None):
if encoding:
self.internal_response.encoding = encoding
return self.internal_response.text
def raise_for_status(self):
self.internal_response.raise_for_status()
class RequestsClientResponse(HTTPRequestsClientResponse, ClientResponse):
def stream_download(self, chunk_size=None, callback=None):
# type: (Optional[int], Optional[Callable]) -> Iterator[bytes]
"""Generator for streaming request body data.
:param callback: Custom callback for monitoring progress.
:param int chunk_size:
"""
chunk_size = chunk_size or CONTENT_CHUNK_SIZE
with contextlib.closing(self.internal_response) as response:
# https://github.com/PyCQA/pylint/issues/1437
for chunk in response.iter_content(chunk_size): # pylint: disable=no-member
if not chunk:
break
if callback and callable(callback):
callback(chunk, response=response)
yield chunk
class BasicRequestsHTTPSender(HTTPSender):
"""Implements a basic requests HTTP sender.
Since requests team recommends to use one session per requests, you should
not consider this class as thread-safe, since it will use one Session
per instance.
In this simple implementation:
- You provide the configured session if you want to, or a basic session is created.
- All kwargs received by "send" are sent to session.request directly
"""
def __init__(self, session=None):
# type: (Optional[requests.Session]) -> None
self.session = session or requests.Session()
def __enter__(self):
# type: () -> BasicRequestsHTTPSender
return self
def __exit__(self, *exc_details): # pylint: disable=arguments-differ
self.close()
def close(self):
self.session.close()
def send(self, request, **kwargs):
# type: (ClientRequest, Any) -> ClientResponse
"""Send request object according to configuration.
Allowed kwargs are:
- session : will override the driver session and use yours. Should NOT be done unless really required.
- anything else is sent straight to requests.
:param ClientRequest request: The request object to be sent.
"""
# It's not recommended to provide its own session, and is mostly
# to enable some legacy code to plug correctly
session = kwargs.pop('session', self.session)
try:
response = session.request(
request.method,
request.url,
**kwargs)
except requests.RequestException as err:
msg = "Error occurred in request."
raise_with_traceback(ClientRequestError, msg, err)
return RequestsClientResponse(request, response)
def _patch_redirect(session):
# type: (requests.Session) -> None
"""Whether redirect policy should be applied based on status code.
HTTP spec says that on 301/302 not HEAD/GET, should NOT redirect.
But requests does, to follow browser more than spec
https://github.com/requests/requests/blob/f6e13ccfc4b50dc458ee374e5dba347205b9a2da/requests/sessions.py#L305-L314
This patches "requests" to be more HTTP compliant.
Note that this is super dangerous, since technically this is not public API.
"""
def enforce_http_spec(resp, request):
if resp.status_code in (301, 302) and \
request.method not in ['GET', 'HEAD']:
return False
return True
redirect_logic = session.resolve_redirects
def wrapped_redirect(resp, req, **kwargs):
attempt = enforce_http_spec(resp, req)
return redirect_logic(resp, req, **kwargs) if attempt else []
wrapped_redirect.is_msrest_patched = True # type: ignore
session.resolve_redirects = wrapped_redirect # type: ignore
class RequestsHTTPSender(BasicRequestsHTTPSender):
"""A requests HTTP sender that can consume a msrest.Configuration object.
This instance will consume the following configuration attributes:
- connection
- proxies
- retry_policy
- redirect_policy
- enable_http_logger
- hooks
- session_configuration_callback
"""
_protocols = ['http://', 'https://']
# Set of authorized kwargs at the operation level
_REQUESTS_KWARGS = [
'cookies',
'verify',
'timeout',
'allow_redirects',
'proxies',
'verify',
'cert'
]
def __init__(self, config=None):
# type: (Optional[RequestHTTPSenderConfiguration]) -> None
self._session_mapping = threading.local()
self.config = config or RequestHTTPSenderConfiguration()
super(RequestsHTTPSender, self).__init__()
@property # type: ignore
def session(self):
try:
return self._session_mapping.session
except AttributeError:
self._session_mapping.session = requests.Session()
self._init_session(self._session_mapping.session)
return self._session_mapping.session
@session.setter
def session(self, value):
self._init_session(value)
self._session_mapping.session = value
def _init_session(self, session):
# type: (requests.Session) -> None
"""Init session level configuration of requests.
This is initialization I want to do once only on a session.
"""
_patch_redirect(session)
# Change max_retries in current all installed adapters
max_retries = self.config.retry_policy()
for protocol in self._protocols:
session.adapters[protocol].max_retries = max_retries
def _configure_send(self, request, **kwargs):
# type: (ClientRequest, Any) -> Dict[str, str]
"""Configure the kwargs to use with requests.
See "send" for kwargs details.
:param ClientRequest request: The request object to be sent.
:returns: The requests.Session.request kwargs
:rtype: dict[str,str]
"""
requests_kwargs = {} # type: Any
session = kwargs.pop('session', self.session)
# If custom session was not create here
if session is not self.session:
self._init_session(session)
session.max_redirects = int(self.config.redirect_policy())
session.trust_env = bool(self.config.proxies.use_env_settings)
# Initialize requests_kwargs with "config" value
requests_kwargs.update(self.config.connection())
requests_kwargs['allow_redirects'] = bool(self.config.redirect_policy)
requests_kwargs['headers'] = self.config.headers.copy()
proxies = self.config.proxies()
if proxies:
requests_kwargs['proxies'] = proxies
# Replace by operation level kwargs
# We allow some of them, since some like stream or json are controled by msrest
for key in kwargs:
if key in self._REQUESTS_KWARGS:
requests_kwargs[key] = kwargs[key]
# Hooks. Deprecated, should be a policy
def make_user_hook_cb(user_hook, session):
def user_hook_cb(r, *args, **kwargs):
kwargs.setdefault("msrest", {})['session'] = session
return user_hook(r, *args, **kwargs)
return user_hook_cb
hooks = []
for user_hook in self.config.hooks:
hooks.append(make_user_hook_cb(user_hook, self.session))
if hooks:
requests_kwargs['hooks'] = {'response': hooks}
# Configuration callback. Deprecated, should be a policy
output_kwargs = self.config.session_configuration_callback(
session,
self.config,
kwargs,
**requests_kwargs
)
if output_kwargs is not None:
requests_kwargs = output_kwargs
# If custom session was not create here
if session is not self.session:
requests_kwargs['session'] = session
### Autorest forced kwargs now ###
# If Autorest needs this response to be streamable. True for compat.
requests_kwargs['stream'] = kwargs.get('stream', True)
if request.files:
requests_kwargs['files'] = request.files
elif request.data:
requests_kwargs['data'] = request.data
requests_kwargs['headers'].update(request.headers)
return requests_kwargs
def send(self, request, **kwargs):
# type: (ClientRequest, Any) -> ClientResponse
"""Send request object according to configuration.
Available kwargs:
- session : will override the driver session and use yours. Should NOT be done unless really required.
- A subset of what requests.Session.request can receive:
- cookies
- verify
- timeout
- allow_redirects
- proxies
- verify
- cert
Everything else will be silently ignored.
:param ClientRequest request: The request object to be sent.
"""
requests_kwargs = self._configure_send(request, **kwargs)
return super(RequestsHTTPSender, self).send(request, **requests_kwargs)
class ClientRetryPolicy(object):
"""Retry configuration settings.
Container for retry policy object.
"""
safe_codes = [i for i in range(500) if i != 408] + [501, 505]
def __init__(self):
self.policy = Retry()
self.policy.total = 3
self.policy.connect = 3
self.policy.read = 3
self.policy.backoff_factor = 0.8
self.policy.BACKOFF_MAX = 90
retry_codes = [i for i in range(999) if i not in self.safe_codes]
self.policy.status_forcelist = retry_codes
self.policy.method_whitelist = ['HEAD', 'TRACE', 'GET', 'PUT',
'OPTIONS', 'DELETE', 'POST', 'PATCH']
def __call__(self):
# type: () -> Retry
"""Return configuration to be applied to connection."""
debug = ("Configuring retry: max_retries=%r, "
"backoff_factor=%r, max_backoff=%r")
_LOGGER.debug(
debug, self.retries, self.backoff_factor, self.max_backoff)
return self.policy
@property
def retries(self):
# type: () -> int
"""Total number of allowed retries."""
return self.policy.total
@retries.setter
def retries(self, value):
# type: (int) -> None
self.policy.total = value
self.policy.connect = value
self.policy.read = value
@property
def backoff_factor(self):
# type: () -> Union[int, float]
"""Factor by which back-off delay is incementally increased."""
return self.policy.backoff_factor
@backoff_factor.setter
def backoff_factor(self, value):
# type: (Union[int, float]) -> None
self.policy.backoff_factor = value
@property
def max_backoff(self):
# type: () -> int
"""Max retry back-off delay."""
return self.policy.BACKOFF_MAX
@max_backoff.setter
def max_backoff(self, value):
# type: (int) -> None
self.policy.BACKOFF_MAX = value
def default_session_configuration_callback(session, global_config, local_config, **kwargs): # pylint: disable=unused-argument
# type: (requests.Session, RequestHTTPSenderConfiguration, Dict[str,str], str) -> Dict[str, str]
"""Configuration callback if you need to change default session configuration.
:param requests.Session session: The session.
:param Configuration global_config: The global configuration.
:param dict[str,str] local_config: The on-the-fly configuration passed on the call.
:param dict[str,str] kwargs: The current computed values for session.request method.
:return: Must return kwargs, to be passed to session.request. If None is return, initial kwargs will be used.
:rtype: dict[str,str]
"""
return kwargs
class RequestHTTPSenderConfiguration(HTTPSenderConfiguration):
"""Requests specific HTTP sender configuration.
:param str filepath: Path to existing config file (optional).
"""
def __init__(self, filepath=None):
# type: (Optional[str]) -> None
super(RequestHTTPSenderConfiguration, self).__init__()
# Retry configuration
self.retry_policy = ClientRetryPolicy()
# Requests hooks. Must respect requests hook callback signature
# Note that we will inject the following parameters:
# - kwargs['msrest']['session'] with the current session
self.hooks = [] # type: List[Callable[[requests.Response, str, str], None]]
self.session_configuration_callback = default_session_configuration_callback
if filepath:
self.load(filepath)
def save(self, filepath):
"""Save current configuration to file.
:param str filepath: Path to file where settings will be saved.
:raises: ValueError if supplied filepath cannot be written to.
"""
self._config.add_section("RetryPolicy")
self._config.set("RetryPolicy", "retries", str(self.retry_policy.retries))
self._config.set("RetryPolicy", "backoff_factor",
str(self.retry_policy.backoff_factor))
self._config.set("RetryPolicy", "max_backoff",
str(self.retry_policy.max_backoff))
super(RequestHTTPSenderConfiguration, self).save(filepath)
def load(self, filepath):
try:
self.retry_policy.retries = \
self._config.getint("RetryPolicy", "retries")
self.retry_policy.backoff_factor = \
self._config.getfloat("RetryPolicy", "backoff_factor")
self.retry_policy.max_backoff = \
self._config.getint("RetryPolicy", "max_backoff")
except (ValueError, EnvironmentError, NoOptionError):
error = "Supplied config file incompatible."
raise_with_traceback(ValueError, error)
finally:
self._clear_config()
super(RequestHTTPSenderConfiguration, self).load(filepath)

Просмотреть файл

@ -25,4 +25,4 @@
# --------------------------------------------------------------------------
#: version of this package. Use msrest.__version__ instead
msrest_version = "0.5.5"
msrest_version = "0.6.0rc1"

44
pylintrc Normal file
Просмотреть файл

@ -0,0 +1,44 @@
[MASTER]
ignore-patterns=test_*
reports=no
[MESSAGES CONTROL]
# For all codes, run 'pylint --list-msgs' or go to 'https://pylint.readthedocs.io/en/latest/reference_guide/features.html'
# locally-disabled: Warning locally suppressed using disable-msg
# cyclic-import: because of https://github.com/PyCQA/pylint/issues/850
# too-many-arguments: Due to the nature of the CLI many commands have large arguments set which reflect in large arguments set in corresponding methods.
disable=missing-docstring,locally-disabled,fixme,cyclic-import,too-many-arguments,invalid-name,duplicate-code
[FORMAT]
max-line-length=120
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=yes
[DESIGN]
# Maximum number of locals for function / method body
max-locals=25
# Maximum number of branch for function / method body
max-branches=20
[SIMILARITIES]
min-similarity-lines=10
[BASIC]
# Naming hints based on PEP 8 (https://www.python.org/dev/peps/pep-0008/#naming-conventions).
# Consider these guidelines and not hard rules. Read PEP 8 for more details.
# The invalid-name checker must be **enabled** for these hints to be used.
include-naming-hint=yes
module-name-hint=lowercase (keep short; underscores are discouraged)
const-name-hint=UPPER_CASE_WITH_UNDERSCORES
class-name-hint=CapitalizedWords
class-attribute-name-hint=lower_case_with_underscores
attr-name-hint=lower_case_with_underscores
method-name-hint=lower_case_with_underscores
function-name-hint=lower_case_with_underscores
argument-name-hint=lower_case_with_underscores
variable-name-hint=lower_case_with_underscores
inlinevar-name-hint=lower_case_with_underscores (short is OK)

Просмотреть файл

@ -2,4 +2,7 @@
universal=1
[mypy]
ignore_missing_imports = True
ignore_missing_imports = True
[tool:pytest]
addopts = --durations=10

Просмотреть файл

@ -28,7 +28,7 @@ from setuptools import setup, find_packages
setup(
name='msrest',
version='0.5.5',
version='0.6.0rc1',
author='Microsoft Corporation',
packages=find_packages(exclude=["tests", "tests.*"]),
url=("https://github.com/Azure/msrest-for-python"),
@ -56,5 +56,9 @@ setup(
extras_require={
":python_version<'3.4'": ['enum34>=1.0.4'],
":python_version<'3.5'": ['typing'],
"async:python_version>='3.5'": [
'aiohttp>=3.0',
'aiodns'
],
}
)

Просмотреть файл

@ -0,0 +1,188 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#--------------------------------------------------------------------------
import io
import asyncio
import json
import unittest
try:
from unittest import mock
except ImportError:
import mock
import sys
import pytest
import requests
from requests.adapters import HTTPAdapter
from oauthlib import oauth2
from msrest import ServiceClient
from msrest.authentication import OAuthTokenAuthentication
from msrest.configuration import Configuration
from msrest import Configuration
from msrest.exceptions import ClientRequestError, TokenExpiredError
from msrest.universal_http import ClientRequest
from msrest.universal_http.async_requests import AsyncRequestsClientResponse
@unittest.skipIf(sys.version_info < (3, 5, 2), "Async tests only on 3.5.2 minimal")
class TestServiceClient(object):
@pytest.mark.asyncio
async def test_client_send(self):
cfg = Configuration("/")
cfg.headers = {'Test': 'true'}
creds = mock.create_autospec(OAuthTokenAuthentication)
client = ServiceClient(creds, cfg)
req_response = requests.Response()
req_response._content = br'{"real": true}' # Has to be valid bytes JSON
req_response._content_consumed = True
req_response.status_code = 200
def side_effect(*args, **kwargs):
return req_response
session = mock.create_autospec(requests.Session)
session.request.side_effect = side_effect
session.adapters = {
"http://": HTTPAdapter(),
"https://": HTTPAdapter(),
}
# Be sure the mock does not trick me
assert not hasattr(session.resolve_redirects, 'is_msrest_patched')
client.config.async_pipeline._sender.driver.session = session
client._creds.signed_session.return_value = session
client._creds.refresh_session.return_value = session
request = ClientRequest('GET', '/')
await client.async_send(request, stream=False)
session.request.call_count = 0
session.request.assert_called_with(
'GET',
'/',
allow_redirects=True,
cert=None,
headers={
'User-Agent': cfg.user_agent,
'Test': 'true' # From global config
},
stream=False,
timeout=100,
verify=True
)
assert session.resolve_redirects.is_msrest_patched
request = client.get('/', headers={'id':'1234'}, content={'Test':'Data'})
await client.async_send(request, stream=False)
session.request.assert_called_with(
'GET',
'/',
data='{"Test": "Data"}',
allow_redirects=True,
cert=None,
headers={
'User-Agent': cfg.user_agent,
'Content-Length': '16',
'id':'1234',
'Accept': 'application/json',
'Test': 'true' # From global config
},
stream=False,
timeout=100,
verify=True
)
assert session.request.call_count == 1
session.request.call_count = 0
assert session.resolve_redirects.is_msrest_patched
request = client.get('/', headers={'id':'1234'}, content={'Test':'Data'})
session.request.side_effect = requests.RequestException("test")
with pytest.raises(ClientRequestError):
await client.async_send(request, test='value', stream=False)
session.request.assert_called_with(
'GET',
'/',
data='{"Test": "Data"}',
allow_redirects=True,
cert=None,
headers={
'User-Agent': cfg.user_agent,
'Content-Length': '16',
'id':'1234',
'Accept': 'application/json',
'Test': 'true' # From global config
},
stream=False,
timeout=100,
verify=True
)
assert session.request.call_count == 1
session.request.call_count = 0
assert session.resolve_redirects.is_msrest_patched
session.request.side_effect = oauth2.rfc6749.errors.InvalidGrantError("test")
with pytest.raises(TokenExpiredError):
await client.async_send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value')
assert session.request.call_count == 2
session.request.call_count = 0
session.request.side_effect = ValueError("test")
with pytest.raises(ValueError):
await client.async_send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value')
@pytest.mark.asyncio
async def test_client_stream_download(self):
req_response = requests.Response()
req_response._content = "abc"
req_response._content_consumed = True
req_response.status_code = 200
client_response = AsyncRequestsClientResponse(
None,
req_response
)
def user_callback(chunk, response):
assert response is req_response
assert chunk in ["a", "b", "c"]
async_iterator = client_response.stream_download(1, user_callback)
result = ""
async for value in async_iterator:
result += value
assert result == "abc"
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -0,0 +1,173 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#--------------------------------------------------------------------------
import sys
import unittest
import pytest
from msrest.paging import Paged
class FakePaged(Paged):
_attribute_map = {
'next_link': {'key': 'nextLink', 'type': 'str'},
'current_page': {'key': 'value', 'type': '[str]'}
}
def __init__(self, *args, **kwargs):
super(FakePaged, self).__init__(*args, **kwargs)
class TestPaging(object):
@pytest.mark.asyncio
async def test_basic_paging(self):
async def internal_paging(next_link=None, raw=False):
if not next_link:
return {
'nextLink': 'page2',
'value': ['value1.0', 'value1.1']
}
else:
return {
'nextLink': None,
'value': ['value2.0', 'value2.1']
}
deserialized = FakePaged(None, {}, async_command=internal_paging)
# 3.6 only : result_iterated = [obj async for obj in deserialized]
result_iterated = []
async for obj in deserialized:
result_iterated.append(obj)
assert ['value1.0', 'value1.1', 'value2.0', 'value2.1'] == result_iterated
@pytest.mark.asyncio
async def test_advance_paging(self):
async def internal_paging(next_link=None, raw=False):
if not next_link:
return {
'nextLink': 'page2',
'value': ['value1.0', 'value1.1']
}
else:
return {
'nextLink': None,
'value': ['value2.0', 'value2.1']
}
deserialized = FakePaged(None, {}, async_command=internal_paging)
page1 = await deserialized.async_advance_page()
assert ['value1.0', 'value1.1'] == page1
page2 = await deserialized.async_advance_page()
assert ['value2.0', 'value2.1'] == page2
with pytest.raises(StopAsyncIteration):
await deserialized.async_advance_page()
@pytest.mark.asyncio
async def test_get_paging(self):
async def internal_paging(next_link=None, raw=False):
if not next_link:
return {
'nextLink': 'page2',
'value': ['value1.0', 'value1.1']
}
elif next_link == 'page2':
return {
'nextLink': 'page3',
'value': ['value2.0', 'value2.1']
}
else:
return {
'nextLink': None,
'value': ['value3.0', 'value3.1']
}
deserialized = FakePaged(None, {}, async_command=internal_paging)
page2 = await deserialized.async_get('page2')
assert ['value2.0', 'value2.1'] == page2
page3 = await deserialized.async_get('page3')
assert ['value3.0', 'value3.1'] == page3
@pytest.mark.asyncio
async def test_reset_paging(self):
async def internal_paging(next_link=None, raw=False):
if not next_link:
return {
'nextLink': 'page2',
'value': ['value1.0', 'value1.1']
}
else:
return {
'nextLink': None,
'value': ['value2.0', 'value2.1']
}
deserialized = FakePaged(None, {}, async_command=internal_paging)
deserialized.reset()
# 3.6 only : result_iterated = [obj async for obj in deserialized]
result_iterated = []
async for obj in deserialized:
result_iterated.append(obj)
assert ['value1.0', 'value1.1', 'value2.0', 'value2.1'] == result_iterated
deserialized = FakePaged(None, {}, async_command=internal_paging)
# Push the iterator to the last element
async for element in deserialized:
if element == "value2.0":
break
deserialized.reset()
# 3.6 only : result_iterated = [obj async for obj in deserialized]
result_iterated = []
async for obj in deserialized:
result_iterated.append(obj)
assert ['value1.0', 'value1.1', 'value2.0', 'value2.1'] == result_iterated
@pytest.mark.asyncio
async def test_none_value(self):
async def internal_paging(next_link=None, raw=False):
return {
'nextLink': None,
'value': None
}
deserialized = FakePaged(None, {}, async_command=internal_paging)
# 3.6 only : result_iterated = [obj async for obj in deserialized]
result_iterated = []
async for obj in deserialized:
result_iterated.append(obj)
assert len(result_iterated) == 0

Просмотреть файл

@ -0,0 +1,128 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#--------------------------------------------------------------------------
import sys
from msrest.universal_http import (
ClientRequest,
)
from msrest.universal_http.async_requests import (
AsyncRequestsHTTPSender,
AsyncTrioRequestsHTTPSender,
)
from msrest.pipeline import (
AsyncPipeline,
AsyncHTTPSender,
SansIOHTTPPolicy
)
from msrest.pipeline.async_requests import AsyncPipelineRequestsHTTPSender
from msrest.pipeline.universal import UserAgentPolicy
from msrest.pipeline.aiohttp import AioHTTPSender
from msrest.configuration import Configuration
import trio
import pytest
@pytest.mark.asyncio
async def test_sans_io_exception():
class BrokenSender(AsyncHTTPSender):
async def send(self, request, **config):
raise ValueError("Broken")
async def __aexit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
pipeline = AsyncPipeline([SansIOHTTPPolicy()], BrokenSender())
req = ClientRequest('GET', '/')
with pytest.raises(ValueError):
await pipeline.run(req)
class SwapExec(SansIOHTTPPolicy):
def on_exception(self, requests, **kwargs):
exc_type, exc_value, exc_traceback = sys.exc_info()
raise NotImplementedError(exc_value)
pipeline = AsyncPipeline([SwapExec()], BrokenSender())
with pytest.raises(NotImplementedError):
await pipeline.run(req)
@pytest.mark.asyncio
async def test_basic_aiohttp():
request = ClientRequest("GET", "http://bing.com")
policies = [
UserAgentPolicy("myusergant")
]
async with AsyncPipeline(policies) as pipeline:
response = await pipeline.run(request)
assert pipeline._sender.driver._session.closed
assert response.http_response.status_code == 200
@pytest.mark.asyncio
async def test_basic_async_requests():
request = ClientRequest("GET", "http://bing.com")
policies = [
UserAgentPolicy("myusergant")
]
async with AsyncPipeline(policies, AsyncPipelineRequestsHTTPSender()) as pipeline:
response = await pipeline.run(request)
assert response.http_response.status_code == 200
@pytest.mark.asyncio
async def test_conf_async_requests():
conf = Configuration("http://bing.com/")
request = ClientRequest("GET", "http://bing.com/")
policies = [
UserAgentPolicy("myusergant")
]
async with AsyncPipeline(policies, AsyncPipelineRequestsHTTPSender(AsyncRequestsHTTPSender(conf))) as pipeline:
response = await pipeline.run(request)
assert response.http_response.status_code == 200
def test_conf_async_trio_requests():
async def do():
conf = Configuration("http://bing.com/")
request = ClientRequest("GET", "http://bing.com/")
policies = [
UserAgentPolicy("myusergant")
]
async with AsyncPipeline(policies, AsyncPipelineRequestsHTTPSender(AsyncTrioRequestsHTTPSender(conf))) as pipeline:
return await pipeline.run(request)
response = trio.run(do)
assert response.http_response.status_code == 200

Просмотреть файл

@ -0,0 +1,167 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#--------------------------------------------------------------------------
import asyncio
try:
from unittest import mock
except ImportError:
import mock
import pytest
from msrest.polling.async_poller import *
from msrest.service_client import ServiceClient
from msrest.serialization import Model
from msrest.configuration import Configuration
@pytest.mark.asyncio
async def test_abc_polling():
abc_polling = AsyncPollingMethod()
with pytest.raises(NotImplementedError):
abc_polling.initialize(None, None, None)
with pytest.raises(NotImplementedError):
await abc_polling.run()
with pytest.raises(NotImplementedError):
abc_polling.status()
with pytest.raises(NotImplementedError):
abc_polling.finished()
with pytest.raises(NotImplementedError):
abc_polling.resource()
@pytest.mark.asyncio
async def test_no_polling():
no_polling = AsyncNoPolling()
initial_response = "initial response"
def deserialization_cb(response):
assert response == initial_response
return "Treated: "+response
no_polling.initialize(None, initial_response, deserialization_cb)
await no_polling.run() # Should no raise and do nothing
assert no_polling.status() == "succeeded"
assert no_polling.finished()
assert no_polling.resource() == "Treated: "+initial_response
class PollingTwoSteps(AsyncPollingMethod):
"""An empty poller that returns the deserialized initial response.
"""
def __init__(self, sleep=0):
self._initial_response = None
self._deserialization_callback = None
self._sleep = sleep
def initialize(self, _, initial_response, deserialization_callback):
self._initial_response = initial_response
self._deserialization_callback = deserialization_callback
self._finished = False
async def run(self):
"""Empty run, no polling.
"""
self._finished = True
await asyncio.sleep(self._sleep) # Give me time to add callbacks!
def status(self):
"""Return the current status as a string.
:rtype: str
"""
return "succeeded" if self._finished else "running"
def finished(self):
"""Is this polling finished?
:rtype: bool
"""
return self._finished
def resource(self):
return self._deserialization_callback(self._initial_response)
@pytest.fixture
def client():
# We need a ServiceClient instance, but the poller itself don't use it, so we don't need
# Something functionnal
return ServiceClient(None, Configuration("http://example.org"))
@pytest.mark.asyncio
async def test_poller(client):
# Same the poller itself doesn't care about the initial_response, and there is no type constraint here
initial_response = "Initial response"
# Same for deserialization_callback, just pass to the polling_method
def deserialization_callback(response):
assert response == initial_response
return "Treated: "+response
method = AsyncNoPolling()
result = await async_poller(client, initial_response, deserialization_callback, method)
assert result == "Treated: "+initial_response
# Test with a basic Model
class MockedModel(Model):
called = False
@classmethod
def deserialize(cls, data):
assert data == initial_response
cls.called = True
result = await async_poller(client, initial_response, MockedModel, method)
assert MockedModel.called
# Test poller that method do a run
method = PollingTwoSteps(sleep=2)
result = await async_poller(client, initial_response, deserialization_callback, method)
assert result == "Treated: "+initial_response
@pytest.mark.asyncio
async def test_broken_poller(client):
with pytest.raises(ValueError):
await async_poller(None, None, None, None)
class NoPollingError(PollingTwoSteps):
async def run(self):
raise ValueError("Something bad happened")
initial_response = "Initial response"
def deserialization_callback(response):
return "Treated: "+response
method = NoPollingError()
with pytest.raises(ValueError) as excinfo:
await async_poller(client, initial_response, deserialization_callback, method)
assert "Something bad happened" in str(excinfo.value)

Просмотреть файл

@ -0,0 +1,88 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#--------------------------------------------------------------------------
import sys
from msrest.universal_http import (
ClientRequest,
AsyncHTTPSender,
)
from msrest.universal_http.aiohttp import AioHTTPSender
from msrest.universal_http.async_requests import (
AsyncBasicRequestsHTTPSender,
AsyncRequestsHTTPSender,
AsyncTrioRequestsHTTPSender,
)
from msrest.configuration import Configuration
import trio
import pytest
@pytest.mark.asyncio
async def test_basic_aiohttp():
request = ClientRequest("GET", "http://bing.com")
async with AioHTTPSender() as sender:
response = await sender.send(request)
assert response.body() is not None
assert sender._session.closed
assert response.status_code == 200
@pytest.mark.asyncio
async def test_basic_async_requests():
request = ClientRequest("GET", "http://bing.com")
async with AsyncBasicRequestsHTTPSender() as sender:
response = await sender.send(request)
assert response.body() is not None
assert response.status_code == 200
@pytest.mark.asyncio
async def test_conf_async_requests():
conf = Configuration("http://bing.com/")
request = ClientRequest("GET", "http://bing.com/")
async with AsyncRequestsHTTPSender(conf) as sender:
response = await sender.send(request)
assert response.body() is not None
assert response.status_code == 200
def test_conf_async_trio_requests():
async def do():
conf = Configuration("http://bing.com/")
request = ClientRequest("GET", "http://bing.com/")
async with AsyncTrioRequestsHTTPSender(conf) as sender:
return await sender.send(request)
assert response.body() is not None
response = trio.run(do)
assert response.status_code == 200

31
tests/conftest.py Normal file
Просмотреть файл

@ -0,0 +1,31 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import sys
# Ignore collection of async tests for Python 2
collect_ignore = []
if sys.version_info < (3, 5):
collect_ignore.append("asynctests")

Просмотреть файл

@ -31,18 +31,30 @@ try:
from unittest import mock
except ImportError:
import mock
import sys
import requests
from requests.adapters import HTTPAdapter
from oauthlib import oauth2
from msrest import ServiceClient, SDKClient
from msrest.service_client import _RequestsHTTPDriver
from msrest.universal_http import (
ClientRequest,
ClientResponse
)
from msrest.universal_http.requests import (
RequestsHTTPSender,
RequestsClientResponse
)
from msrest.pipeline import (
HTTPSender,
Response,
)
from msrest.authentication import OAuthTokenAuthentication, Authentication
from msrest import Configuration
from msrest.exceptions import ClientRequestError, TokenExpiredError
from msrest.pipeline import ClientRequest
class TestServiceClient(unittest.TestCase):
@ -53,19 +65,6 @@ class TestServiceClient(unittest.TestCase):
self.creds = mock.create_autospec(OAuthTokenAuthentication)
return super(TestServiceClient, self).setUp()
def test_session_callback(self):
with _RequestsHTTPDriver(self.cfg) as driver:
def callback(session, global_config, local_config, **kwargs):
self.assertIs(session, driver.session)
self.assertIs(global_config, self.cfg)
self.assertTrue(local_config["test"])
return {'used_callback': True}
self.cfg.session_configuration_callback = callback
output_kwargs = driver.configure_session(**{"test": True})
self.assertTrue(output_kwargs['used_callback'])
def test_sdk_context_manager(self):
cfg = Configuration("http://127.0.0.1/")
@ -87,14 +86,16 @@ class TestServiceClient(unittest.TestCase):
with SDKClient(creds, cfg) as client:
assert cfg.keep_alive
req = client._client.get()
req = client._client.get('/')
try:
client._client.send(req) # Will fail, I don't care, that's not the point of the test
# Will fail, I don't care, that's not the point of the test
client._client.send(req, timeout=0)
except Exception:
pass
try:
client._client.send(req) # Will fail, I don't care, that's not the point of the test
# Will fail, I don't care, that's not the point of the test
client._client.send(req, timeout=0)
except Exception:
pass
@ -122,17 +123,18 @@ class TestServiceClient(unittest.TestCase):
with ServiceClient(creds, cfg) as client:
assert cfg.keep_alive
req = client.get()
req = client.get('/')
try:
client.send(req) # Will fail, I don't care, that's not the point of the test
# Will fail, I don't care, that's not the point of the test
client.send(req, timeout=0)
except Exception:
pass
try:
client.send(req) # Will fail, I don't care, that's not the point of the test
# Will fail, I don't care, that's not the point of the test
client.send(req, timeout=0)
except Exception:
pass
assert client._http_driver.session # Still alive
assert not cfg.keep_alive
assert creds.called == 2
@ -157,101 +159,60 @@ class TestServiceClient(unittest.TestCase):
creds = Creds()
client = ServiceClient(creds, cfg)
req = client.get()
req = client.get('/')
try:
client.send(req) # Will fail, I don't care, that's not the point of the test
# Will fail, I don't care, that's not the point of the test
client.send(req, timeout=0)
except Exception:
pass
try:
client.send(req) # Will fail, I don't care, that's not the point of the test
# Will fail, I don't care, that's not the point of the test
client.send(req, timeout=0)
except Exception:
pass
assert creds.called == 2
assert client._http_driver.session # Still alive
# Manually close the client in "keep_alive" mode
client.close()
def test_max_retries_on_default_adapter(self):
# max_retries must be applied only on the default adapters of requests
# If the user adds its own adapter, don't touch it
max_retries = self.cfg.retry_policy()
with _RequestsHTTPDriver(self.cfg) as driver:
driver.session.mount('http://example.org', HTTPAdapter())
driver.configure_session()
assert driver.session.adapters["http://"].max_retries is max_retries
assert driver.session.adapters["https://"].max_retries is max_retries
assert driver.session.adapters['http://example.org'].max_retries is not max_retries
def test_no_log(self):
# By default, no log handler for HTTP
with _RequestsHTTPDriver(self.cfg) as driver:
kwargs = driver.configure_session()
assert 'hooks' not in kwargs
# I can enable it per request
with _RequestsHTTPDriver(self.cfg) as driver:
kwargs = driver.configure_session(**{"enable_http_logger": True})
assert 'hooks' in kwargs
# I can enable it per request (bool value should be honored)
with _RequestsHTTPDriver(self.cfg) as driver:
kwargs = driver.configure_session(**{"enable_http_logger": False})
assert 'hooks' not in kwargs
# I can enable it globally
self.cfg.enable_http_logger = True
with _RequestsHTTPDriver(self.cfg) as driver:
kwargs = driver.configure_session()
assert 'hooks' in kwargs
# I can enable it globally and override it locally
self.cfg.enable_http_logger = True
with _RequestsHTTPDriver(self.cfg) as driver:
kwargs = driver.configure_session(**{"enable_http_logger": False})
assert 'hooks' not in kwargs
def test_client_request(self):
client = ServiceClient(self.creds, self.cfg)
obj = client.get()
cfg = Configuration("http://127.0.0.1/")
client = ServiceClient(self.creds, cfg)
obj = client.get('/')
self.assertEqual(obj.method, 'GET')
self.assertIsNone(obj.url)
self.assertEqual(obj.params, {})
self.assertEqual(obj.url, "http://127.0.0.1/")
obj = client.get("/service", {'param':"testing"})
self.assertEqual(obj.method, 'GET')
self.assertEqual(obj.url, "https://my_endpoint.com/service?param=testing")
self.assertEqual(obj.params, {})
self.assertEqual(obj.url, "http://127.0.0.1/service?param=testing")
obj = client.get("service 2")
self.assertEqual(obj.method, 'GET')
self.assertEqual(obj.url, "https://my_endpoint.com/service 2")
self.assertEqual(obj.url, "http://127.0.0.1/service 2")
self.cfg.base_url = "https://my_endpoint.com/"
cfg.base_url = "https://127.0.0.1/"
obj = client.get("//service3")
self.assertEqual(obj.method, 'GET')
self.assertEqual(obj.url, "https://my_endpoint.com/service3")
self.assertEqual(obj.url, "https://127.0.0.1/service3")
obj = client.put()
obj = client.put('/')
self.assertEqual(obj.method, 'PUT')
obj = client.post()
obj = client.post('/')
self.assertEqual(obj.method, 'POST')
obj = client.head()
obj = client.head('/')
self.assertEqual(obj.method, 'HEAD')
obj = client.merge()
obj = client.merge('/')
self.assertEqual(obj.method, 'MERGE')
obj = client.patch()
obj = client.patch('/')
self.assertEqual(obj.method, 'PATCH')
obj = client.delete()
obj = client.delete('/')
self.assertEqual(obj.method, 'DELETE')
def test_format_url(self):
@ -288,53 +249,60 @@ class TestServiceClient(unittest.TestCase):
def test_client_send(self):
class MockHTTPDriver(object):
def configure_session(self, **config):
pass
def send(self, request, **config):
pass
current_ua = self.cfg.user_agent
client = ServiceClient(self.creds, self.cfg)
client.config.keep_alive = True
req_response = requests.Response()
req_response._content = br'{"real": true}' # Has to be valid bytes JSON
req_response._content_consumed = True
req_response.status_code = 200
def side_effect(*args, **kwargs):
return req_response
session = mock.create_autospec(requests.Session)
client._http_driver.session = session
session.request.side_effect = side_effect
session.adapters = {
"http://": HTTPAdapter(),
"https://": HTTPAdapter(),
}
client.creds.signed_session.return_value = session
client.creds.refresh_session.return_value = session
# Be sure the mock does not trick me
assert not hasattr(session.resolve_redirects, 'is_mrest_patched')
assert not hasattr(session.resolve_redirects, 'is_msrest_patched')
request = ClientRequest('GET')
client.config.pipeline._sender.driver.session = session
client._creds.signed_session.return_value = session
client._creds.refresh_session.return_value = session
request = ClientRequest('GET', '/')
client.send(request, stream=False)
session.request.call_count = 0
session.request.assert_called_with(
'GET',
None,
'/',
allow_redirects=True,
cert=None,
headers={
'User-Agent': self.cfg.user_agent,
'User-Agent': current_ua,
'Test': 'true' # From global config
},
stream=False,
timeout=100,
verify=True
)
assert session.resolve_redirects.is_mrest_patched
assert session.resolve_redirects.is_msrest_patched
client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, stream=False)
session.request.assert_called_with(
'GET',
None,
'/',
data='{"Test": "Data"}',
allow_redirects=True,
cert=None,
headers={
'User-Agent': self.cfg.user_agent,
'User-Agent': current_ua,
'Content-Length': '16',
'id':'1234',
'Test': 'true' # From global config
@ -345,19 +313,19 @@ class TestServiceClient(unittest.TestCase):
)
self.assertEqual(session.request.call_count, 1)
session.request.call_count = 0
assert session.resolve_redirects.is_mrest_patched
assert session.resolve_redirects.is_msrest_patched
session.request.side_effect = requests.RequestException("test")
with self.assertRaises(ClientRequestError):
client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value', stream=False)
session.request.assert_called_with(
'GET',
None,
'/',
data='{"Test": "Data"}',
allow_redirects=True,
cert=None,
headers={
'User-Agent': self.cfg.user_agent,
'User-Agent': current_ua,
'Content-Length': '16',
'id':'1234',
'Test': 'true' # From global config
@ -368,7 +336,7 @@ class TestServiceClient(unittest.TestCase):
)
self.assertEqual(session.request.call_count, 1)
session.request.call_count = 0
assert session.resolve_redirects.is_mrest_patched
assert session.resolve_redirects.is_msrest_patched
session.request.side_effect = oauth2.rfc6749.errors.InvalidGrantError("test")
with self.assertRaises(TokenExpiredError):
@ -380,77 +348,95 @@ class TestServiceClient(unittest.TestCase):
with self.assertRaises(ValueError):
client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value')
def test_client_formdata_add(self):
@mock.patch.object(ClientRequest, "_format_data")
def test_client_formdata_add(self, format_data):
format_data.return_value = "formatted"
client = mock.create_autospec(ServiceClient)
client._format_data.return_value = "formatted"
request = ClientRequest('GET')
ServiceClient._add_formdata(client, request)
request = ClientRequest('GET', '/')
request.add_formdata()
assert request.files == {}
request = ClientRequest('GET')
ServiceClient._add_formdata(client, request, {'Test':'Data'})
request = ClientRequest('GET', '/')
request.add_formdata({'Test':'Data'})
assert request.files == {'Test':'formatted'}
request = ClientRequest('GET')
request = ClientRequest('GET', '/')
request.headers = {'Content-Type':'1234'}
ServiceClient._add_formdata(client, request, {'1':'1', '2':'2'})
request.add_formdata({'1':'1', '2':'2'})
assert request.files == {'1':'formatted', '2':'formatted'}
request = ClientRequest('GET')
request = ClientRequest('GET', '/')
request.headers = {'Content-Type':'1234'}
ServiceClient._add_formdata(client, request, {'1':'1', '2':None})
request.add_formdata({'1':'1', '2':None})
assert request.files == {'1':'formatted'}
request = ClientRequest('GET')
request = ClientRequest('GET', '/')
request.headers = {'Content-Type':'application/x-www-form-urlencoded'}
ServiceClient._add_formdata(client, request, {'1':'1', '2':'2'})
assert request.files == []
request.add_formdata({'1':'1', '2':'2'})
assert request.files is None
assert request.data == {'1':'1', '2':'2'}
request = ClientRequest('GET')
request = ClientRequest('GET', '/')
request.headers = {'Content-Type':'application/x-www-form-urlencoded'}
ServiceClient._add_formdata(client, request, {'1':'1', '2':None})
assert request.files == []
request.add_formdata({'1':'1', '2':None})
assert request.files is None
assert request.data == {'1':'1'}
def test_format_data(self):
mock_client = mock.create_autospec(ServiceClient)
data = ServiceClient._format_data(mock_client, None)
data = ClientRequest._format_data(None)
self.assertEqual(data, (None, None))
data = ServiceClient._format_data(mock_client, "Test")
data = ClientRequest._format_data("Test")
self.assertEqual(data, (None, "Test"))
mock_stream = mock.create_autospec(io.BytesIO)
data = ServiceClient._format_data(mock_client, mock_stream)
data = ClientRequest._format_data(mock_stream)
self.assertEqual(data, (None, mock_stream, "application/octet-stream"))
mock_stream.name = "file_name"
data = ServiceClient._format_data(mock_client, mock_stream)
data = ClientRequest._format_data(mock_stream)
self.assertEqual(data, ("file_name", mock_stream, "application/octet-stream"))
def test_client_stream_download(self):
req_response = requests.Response()
req_response._content = "abc"
req_response._content_consumed = True
req_response.status_code = 200
client_response = RequestsClientResponse(
None,
req_response
)
def user_callback(chunk, response):
assert response is req_response
assert chunk in ["a", "b", "c"]
sync_iterator = client_response.stream_download(1, user_callback)
result = ""
for value in sync_iterator:
result += value
assert result == "abc"
def test_request_builder(self):
client = ServiceClient(self.creds, self.cfg)
req = client.get('http://example.org')
req = client.get('http://127.0.0.1/')
assert req.method == 'GET'
assert req.url == 'http://example.org'
assert req.params == {}
assert req.url == 'http://127.0.0.1/'
assert req.headers == {'Accept': 'application/json'}
assert req.data == []
assert req.files == []
assert req.data is None
assert req.files is None
req = client.put('http://example.org', content={'creation': True})
req = client.put("http://127.0.0.1/", content={'creation': True})
assert req.method == 'PUT'
assert req.url == 'http://example.org'
assert req.params == {}
assert req.url == "http://127.0.0.1/"
assert req.headers == {'Content-Length': '18', 'Accept': 'application/json'}
assert req.data == '{"creation": true}'
assert req.files == []
assert req.files is None
if __name__ == '__main__':
unittest.main()
unittest.main()

Просмотреть файл

@ -1,6 +1,6 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
@ -89,15 +89,16 @@ class TestExceptions(unittest.TestCase):
'ErrorDetails': ErrorDetails
})
response = mock.create_autospec(requests.Response)
response.text = json.dumps(
response = requests.Response()
response._content_consumed = True
response._content = json.dumps(
{
"error": {
"code": "NotOptedIn",
"message": "You are not allowed to download invoices. Please contact your account administrator to turn on access in the management portal for allowing to download invoices through the API."
}
}
)
}
).encode('utf-8')
response.headers = {"content-type": "application/json; charset=utf8"}
excep = ErrorResponseException(deserializer, response)

Просмотреть файл

@ -1,6 +1,6 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
@ -34,20 +34,53 @@ try:
except ImportError:
import mock
import xml.etree.ElementTree as ET
import sys
from msrest.pipeline import (
import pytest
from msrest.universal_http import (
ClientRequest,
ClientRawResponse)
)
from msrest.pipeline import (
ClientRawResponse,
SansIOHTTPPolicy,
Pipeline,
HTTPSender
)
from msrest import Configuration
def test_sans_io_exception():
class BrokenSender(HTTPSender):
def send(self, request, **config):
raise ValueError("Broken")
def __exit__(self, exc_type, exc_value, traceback):
"""Raise any exception triggered within the runtime context."""
return None
pipeline = Pipeline([SansIOHTTPPolicy()], BrokenSender())
req = ClientRequest('GET', '/')
with pytest.raises(ValueError):
pipeline.run(req)
class SwapExec(SansIOHTTPPolicy):
def on_exception(self, requests, **kwargs):
exc_type, exc_value, exc_traceback = sys.exc_info()
raise NotImplementedError(exc_value)
pipeline = Pipeline([SwapExec()], BrokenSender())
with pytest.raises(NotImplementedError):
pipeline.run(req)
class TestClientRequest(unittest.TestCase):
def test_request_data(self):
request = ClientRequest()
request = ClientRequest('GET', '/')
data = "Lots of dataaaa"
request.add_content(data)
@ -55,7 +88,7 @@ class TestClientRequest(unittest.TestCase):
self.assertEqual(request.headers.get('Content-Length'), '17')
def test_request_xml(self):
request = ClientRequest()
request = ClientRequest('GET', '/')
data = ET.Element("root")
request.add_content(data)
@ -63,7 +96,7 @@ class TestClientRequest(unittest.TestCase):
def test_request_url_with_params(self):
request = ClientRequest()
request = ClientRequest('GET', '/')
request.url = "a/b/c?t=y"
request.format_parameters({'g': 'h'})

Просмотреть файл

@ -1,6 +1,6 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
@ -34,6 +34,7 @@ import pytest
from msrest.polling import *
from msrest.service_client import ServiceClient
from msrest.serialization import Model
from msrest.configuration import Configuration
def test_abc_polling():
@ -103,11 +104,13 @@ class PollingTwoSteps(PollingMethod):
def resource(self):
return self._deserialization_callback(self._initial_response)
def test_poller():
@pytest.fixture
def client():
# We need a ServiceClient instance, but the poller itself don't use it, so we don't need
# Something functionnal
client = ServiceClient(None, None)
return ServiceClient(None, Configuration("http://example.org"))
def test_poller(client):
# Same the poller itself doesn't care about the initial_response, and there is no type constraint here
initial_response = "Initial response"
@ -115,7 +118,7 @@ def test_poller():
# Same for deserialization_callback, just pass to the polling_method
def deserialization_callback(response):
assert response == initial_response
return "Treated: "+response
return "Treated: "+response
method = NoPolling()
@ -135,7 +138,7 @@ def test_poller():
assert poller._polling_method._deserialization_callback == Model.deserialize
# Test poller that method do a run
method = PollingTwoSteps(sleep=2)
method = PollingTwoSteps(sleep=1)
poller = LROPoller(client, initial_response, deserialization_callback, method)
done_cb = mock.MagicMock()
@ -153,7 +156,7 @@ def test_poller():
poller.remove_done_callback(done_cb)
assert "Process is complete" in str(excinfo.value)
def test_broken_poller():
def test_broken_poller(client):
with pytest.raises(ValueError):
LROPoller(None, None, None, None)
@ -162,10 +165,9 @@ def test_broken_poller():
def run(self):
raise ValueError("Something bad happened")
client = ServiceClient(None, None)
initial_response = "Initial response"
def deserialization_callback(response):
return "Treated: "+response
return "Treated: "+response
method = NoPollingError()
poller = LROPoller(client, initial_response, deserialization_callback, method)
@ -173,4 +175,3 @@ def test_broken_poller():
with pytest.raises(ValueError) as excinfo:
poller.result()
assert "Something bad happened" in str(excinfo.value)

Просмотреть файл

@ -0,0 +1,108 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#--------------------------------------------------------------------------
import concurrent.futures
from requests.adapters import HTTPAdapter
from msrest.universal_http import (
ClientRequest
)
from msrest.universal_http.requests import (
BasicRequestsHTTPSender,
RequestsHTTPSender,
RequestHTTPSenderConfiguration
)
def test_session_callback():
cfg = RequestHTTPSenderConfiguration()
with RequestsHTTPSender(cfg) as driver:
def callback(session, global_config, local_config, **kwargs):
assert session is driver.session
assert global_config is cfg
assert local_config["test"]
my_kwargs = kwargs.copy()
my_kwargs.update({'used_callback': True})
return my_kwargs
cfg.session_configuration_callback = callback
request = ClientRequest('GET', 'http://127.0.0.1/')
output_kwargs = driver._configure_send(request, **{"test": True})
assert output_kwargs['used_callback']
def test_max_retries_on_default_adapter():
# max_retries must be applied only on the default adapters of requests
# If the user adds its own adapter, don't touch it
cfg = RequestHTTPSenderConfiguration()
max_retries = cfg.retry_policy()
with RequestsHTTPSender(cfg) as driver:
request = ClientRequest('GET', '/')
driver.session.mount('"http://127.0.0.1/"', HTTPAdapter())
driver._configure_send(request)
assert driver.session.adapters["http://"].max_retries is max_retries
assert driver.session.adapters["https://"].max_retries is max_retries
assert driver.session.adapters['"http://127.0.0.1/"'].max_retries is not max_retries
def test_threading_basic_requests():
# Basic should have the session for all threads, it's why it's not recommended
sender = BasicRequestsHTTPSender()
main_thread_session = sender.session
def thread_body(local_sender):
# Should be the same session
assert local_sender.session is main_thread_session
return True
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(thread_body, sender)
assert future.result()
def test_threading_cfg_requests():
cfg = RequestHTTPSenderConfiguration()
# The one with conf however, should have one session per thread automatically
sender = RequestsHTTPSender(cfg)
main_thread_session = sender.session
# Check that this main session is patched
assert main_thread_session.resolve_redirects.is_msrest_patched
def thread_body(local_sender):
# Should have it's own session
assert local_sender.session is not main_thread_session
# But should be patched as the main thread session
assert local_sender.session.resolve_redirects.is_msrest_patched
return True
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(thread_body, sender)
assert future.result()

Просмотреть файл

@ -1,6 +1,6 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
@ -45,8 +45,9 @@ except ImportError:
from msrest.authentication import (
Authentication,
OAuthTokenAuthentication)
from msrest.pipeline import (
ClientRequest)
from msrest.universal_http import (
ClientRequest
)
from msrest import (
ServiceClient,
Configuration)
@ -54,12 +55,13 @@ from msrest.exceptions import (
TokenExpiredError,
ClientRequestError)
import pytest
class TestRuntime(unittest.TestCase):
@httpretty.activate
def test_credential_headers(self):
httpretty.register_uri(httpretty.GET, "https://my_service.com/get_endpoint",
body='[{"title": "Test Data"}]',
content_type="application/json")
@ -79,10 +81,12 @@ class TestRuntime(unittest.TestCase):
url = client.format_url("/get_endpoint")
request = client.get(url, {'check':True})
response = client.send(request)
self.assertTrue('Authorization' in response.request.headers)
self.assertEqual(response.request.headers['Authorization'], 'Bearer eswfld123kjhn1v5423')
check = httpretty.last_request()
self.assertEqual(response.json(), [{"title": "Test Data"}])
assert 'Authorization' in response.request.headers
assert response.request.headers['Authorization'] == 'Bearer eswfld123kjhn1v5423'
httpretty.has_request()
assert response.json() == [{"title": "Test Data"}]
# Expiration test
token['expires_in'] = '-30'
creds = OAuthTokenAuthentication("client_id", token)
@ -90,13 +94,13 @@ class TestRuntime(unittest.TestCase):
url = client.format_url("/get_endpoint")
request = client.get(url, {'check':True})
with self.assertRaises(TokenExpiredError):
with pytest.raises(TokenExpiredError):
response = client.send(request)
@mock.patch.object(requests, 'Session')
def test_request_fail(self, mock_requests):
mock_requests.return_value.request.return_value = mock.Mock(_content_consumed=True)
mock_requests.return_value.request.return_value = mock.Mock(text="text")
cfg = Configuration("https://my_service.com")
creds = Authentication()
@ -106,9 +110,7 @@ class TestRuntime(unittest.TestCase):
request = client.get(url, {'check':True})
response = client.send(request)
check = httpretty.last_request()
self.assertTrue(response._content_consumed)
assert response.text == "text"
mock_requests.return_value.request.side_effect = requests.RequestException
with self.assertRaises(ClientRequestError):
@ -131,7 +133,7 @@ class TestRuntime(unittest.TestCase):
url = client.format_url("/get_endpoint")
request = client.get(url, {'check':True})
response = client.send(request)
self.assertEqual(response.json(), "Mocked body")
assert response.json() == "Mocked body"
with mock.patch.dict('os.environ', {'HTTP_PROXY': "http://localhost:1987"}):
httpretty.register_uri(httpretty.GET, "http://localhost:1987/get_endpoint?check=True",
@ -144,7 +146,7 @@ class TestRuntime(unittest.TestCase):
url = client.format_url("/get_endpoint")
request = client.get(url, {'check':True})
response = client.send(request)
self.assertEqual(response.json(), "Mocked body")
assert response.json() == "Mocked body"
class TestRedirect(unittest.TestCase):
@ -171,14 +173,14 @@ class TestRedirect(unittest.TestCase):
responses=[
httpretty.Response(body="", status=303, method='POST', location='/http/success/get/200'),
])
response = self.client.send(request)
self.assertEqual(response.status_code, 200, msg="Should redirect with GET on 303 with location header")
self.assertEqual(response.request.method, 'GET')
self.assertEqual(response.history[0].status_code, 303)
self.assertTrue(response.history[0].is_redirect)
response = self.client.send(request)
assert response.status_code == 200, "Should redirect with GET on 303 with location header"
assert response.request.method == 'GET'
assert response.history[0].status_code == 303
assert response.history[0].is_redirect
httpretty.reset()
httpretty.register_uri(httpretty.POST, "https://my_service.com/get_endpoint",
@ -187,9 +189,9 @@ class TestRedirect(unittest.TestCase):
])
response = self.client.send(request)
self.assertEqual(response.status_code, 303, msg="Should not redirect on 303 without location header")
self.assertEqual(response.history, [])
self.assertFalse(response.is_redirect)
assert response.status_code == 303, "Should not redirect on 303 without location header"
assert response.history == []
assert not response.is_redirect
@httpretty.activate
def test_request_redirect_head(self):
@ -202,14 +204,14 @@ class TestRedirect(unittest.TestCase):
responses=[
httpretty.Response(body="", status=307, method='HEAD', location='/http/success/200'),
])
response = self.client.send(request)
self.assertEqual(response.status_code, 200, msg="Should redirect on 307 with location header")
self.assertEqual(response.request.method, 'HEAD')
self.assertEqual(response.history[0].status_code, 307)
self.assertTrue(response.history[0].is_redirect)
response = self.client.send(request)
assert response.status_code == 200, "Should redirect on 307 with location header"
assert response.request.method == 'HEAD'
assert response.history[0].status_code == 307
assert response.history[0].is_redirect
httpretty.reset()
httpretty.register_uri(httpretty.HEAD, "https://my_service.com/get_endpoint",
@ -218,9 +220,9 @@ class TestRedirect(unittest.TestCase):
])
response = self.client.send(request)
self.assertEqual(response.status_code, 307, msg="Should not redirect on 307 without location header")
self.assertEqual(response.history, [])
self.assertFalse(response.is_redirect)
assert response.status_code == 307, "Should not redirect on 307 without location header"
assert response.history == []
assert not response.is_redirect
@httpretty.activate
def test_request_redirect_delete(self):
@ -233,14 +235,14 @@ class TestRedirect(unittest.TestCase):
responses=[
httpretty.Response(body="", status=307, method='DELETE', location='/http/success/200'),
])
response = self.client.send(request)
self.assertEqual(response.status_code, 200, msg="Should redirect on 307 with location header")
self.assertEqual(response.request.method, 'DELETE')
self.assertEqual(response.history[0].status_code, 307)
self.assertTrue(response.history[0].is_redirect)
response = self.client.send(request)
assert response.status_code == 200, "Should redirect on 307 with location header"
assert response.request.method == 'DELETE'
assert response.history[0].status_code == 307
assert response.history[0].is_redirect
httpretty.reset()
httpretty.register_uri(httpretty.DELETE, "https://my_service.com/get_endpoint",
@ -249,9 +251,9 @@ class TestRedirect(unittest.TestCase):
])
response = self.client.send(request)
self.assertEqual(response.status_code, 307, msg="Should not redirect on 307 without location header")
self.assertEqual(response.history, [])
self.assertFalse(response.is_redirect)
assert response.status_code == 307, "Should not redirect on 307 without location header"
assert response.history == []
assert not response.is_redirect
@httpretty.activate
def test_request_redirect_put(self):
@ -265,9 +267,9 @@ class TestRedirect(unittest.TestCase):
])
response = self.client.send(request)
self.assertEqual(response.status_code, 305, msg="Should not redirect on 305")
self.assertEqual(response.history, [])
self.assertFalse(response.is_redirect)
assert response.status_code == 305, "Should not redirect on 305"
assert response.history == []
assert not response.is_redirect
@httpretty.activate
def test_request_redirect_get(self):
@ -301,7 +303,7 @@ class TestRedirect(unittest.TestCase):
])
with self.assertRaises(ClientRequestError, msg="Should exceed maximum redirects"):
response = self.client.send(request)
self.client.send(request)
@ -325,8 +327,8 @@ class TestRuntimeRetry(unittest.TestCase):
httpretty.Response(body="retry response", status=502),
httpretty.Response(body='success response', status=202),
])
response = self.client.send(self.request)
self.assertEqual(response.status_code, 202, msg="Should retry on 502")
@ -364,8 +366,8 @@ class TestRuntimeRetry(unittest.TestCase):
])
with self.assertRaises(ClientRequestError, msg="Max retries reached"):
response = self.client.send(self.request)
self.client.send(self.request)
@httpretty.activate
def test_request_retry_404(self):
httpretty.register_uri(httpretty.GET, "https://my_service.com/get_endpoint",
@ -399,6 +401,6 @@ class TestRuntimeRetry(unittest.TestCase):
response = self.client.send(self.request)
self.assertEqual(response.status_code, 505, msg="Shouldn't retry on 505")
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -31,10 +31,6 @@ import logging
from enum import Enum
from datetime import datetime, timedelta, date
import unittest
try:
from unittest import mock
except ImportError:
import mock
import xml.etree.ElementTree as ET
@ -180,10 +176,7 @@ class TestModelDeserialization(unittest.TestCase):
"location": "westus"
}
resp = mock.create_autospec(Response)
resp.text = json.dumps(data)
resp.headers = {"content-type": "application/json; charset=utf8"}
model = self.d('GenericResource', resp)
model = self.d('GenericResource', json.dumps(data), 'application/json')
self.assertEqual(model.properties['platformFaultDomainCount'], 3)
self.assertEqual(model.location, 'westus')
@ -1305,47 +1298,6 @@ class TestRuntimeDeserialized(unittest.TestCase):
self.d = Deserializer()
return super(TestRuntimeDeserialized, self).setUp()
def test_unpack(self):
result = Deserializer._unpack_content("<groot/>", content_type="application/xml")
assert result.tag == "groot"
# Catch some weird situation where content_type is XML, but content is JSON
result = Deserializer._unpack_content('{"ugly": true}', content_type="application/xml")
assert result["ugly"] is True
# Be sure I catch the correct exception if it's neither XML nor JSON
with pytest.raises(ET.ParseError):
result = Deserializer._unpack_content('gibberish', content_type="application/xml")
with pytest.raises(ET.ParseError):
result = Deserializer._unpack_content('{{gibberish}}', content_type="application/xml")
result = Deserializer._unpack_content('{"success": true}', content_type="application/json")
assert result["success"] is True
# For compat, if no content-type, and direct string, just return it
result = Deserializer._unpack_content('data')
assert result == "data"
# Decore bytes
result = Deserializer._unpack_content(b'data')
assert result == "data"
response = Response()
response.headers["content-type"] = "application/json"
response._content = b'{"success": true}'
response._content_consumed = True
result = Deserializer._unpack_content(response)
assert result["success"] is True
# If no content-type, assume it's JSON
response = Response()
response._content = b'{"success": true}'
response._content_consumed = True
result = Deserializer._unpack_content(response)
assert result["success"] is True
def test_cls_method_deserialization(self):
json_data = {
'id': 'myid',
@ -1517,80 +1469,56 @@ class TestRuntimeDeserialized(unittest.TestCase):
"""
Test invalid JSON
"""
response_data = mock.create_autospec(Response)
response_data.headers = {"content-type": "application/json; charset=utf8"}
response_data.text = '["tata"]]'
with self.assertRaises(DeserializationError):
self.d("[str]", response_data)
self.d("[str]", '["tata"]]', 'application/json')
def test_non_obj_deserialization(self):
"""
Test direct deserialization of simple types.
"""
response_data = mock.create_autospec(Response)
response_data.headers = {"content-type": "application/json; charset=utf8"}
response_data.text = ''
with self.assertRaises(DeserializationError):
self.d("[str]", response_data)
self.d("[str]", '', 'application/json')
response_data.text = json.dumps('')
with self.assertRaises(DeserializationError):
self.d("[str]", response_data)
self.d("[str]", json.dumps(''), 'application/json')
response_data.text = json.dumps({})
with self.assertRaises(DeserializationError):
self.d("[str]", response_data)
self.d("[str]", json.dumps({}), 'application/json')
message = ["a","b","b"]
response_data.text = json.dumps(message)
response = self.d("[str]", response_data)
response = self.d("[str]", json.dumps(message), 'application/json')
self.assertEqual(response, message)
response_data.text = json.dumps(12345)
with self.assertRaises(DeserializationError):
self.d("[str]", response_data)
self.d("[str]", json.dumps(12345), 'application/json')
response_data.text = json.dumps('true')
response = self.d('bool', response_data)
response = self.d('bool', json.dumps('true'), 'application/json')
self.assertEqual(response, True)
response_data.text = json.dumps(1)
response = self.d('bool', response_data)
response = self.d('bool', json.dumps(1), 'application/json')
self.assertEqual(response, True)
response_data.text = json.dumps("true1")
with self.assertRaises(DeserializationError):
self.d('bool', response_data)
self.d('bool', json.dumps("true1"), 'application/json')
def test_obj_with_no_attr(self):
"""
Test deserializing an object with no attributes.
"""
response_data = mock.create_autospec(Response)
response_data.text = json.dumps({"a":"b"})
response_data.headers = {"content-type": "application/json; charset=utf8"}
class EmptyResponse(Model):
_attribute_map = {}
_header_map = {}
derserialized = self.d(EmptyResponse, response_data)
self.assertIsInstance(derserialized, EmptyResponse)
deserialized = self.d(EmptyResponse, json.dumps({"a":"b"}), 'application/json')
self.assertIsInstance(deserialized, EmptyResponse)
def test_obj_with_malformed_map(self):
"""
Test deserializing an object with a malformed attributes_map.
"""
response_data = mock.create_autospec(Response)
response_data.text = json.dumps({"a":"b"})
response_data.headers = {"content-type": "application/json; charset=utf8"}
class BadResponse(Model):
_attribute_map = None
@ -1598,7 +1526,7 @@ class TestRuntimeDeserialized(unittest.TestCase):
pass
with self.assertRaises(DeserializationError):
self.d(BadResponse, response_data)
self.d(BadResponse, json.dumps({"a":"b"}), 'application/json')
class BadResponse(Model):
_attribute_map = {"attr":"val"}
@ -1607,7 +1535,7 @@ class TestRuntimeDeserialized(unittest.TestCase):
pass
with self.assertRaises(DeserializationError):
self.d(BadResponse, response_data)
self.d(BadResponse, json.dumps({"a":"b"}), 'application/json')
class BadResponse(Model):
_attribute_map = {"attr":{"val":1}}
@ -1616,141 +1544,89 @@ class TestRuntimeDeserialized(unittest.TestCase):
pass
with self.assertRaises(DeserializationError):
self.d(BadResponse, response_data)
self.d(BadResponse, json.dumps({"a":"b"}), 'application/json')
def test_attr_none(self):
"""
Test serializing an object with None attributes.
"""
response_data = mock.create_autospec(Response)
response_data.headers = {"content-type": "application/json; charset=utf8"}
response_data.text = 'null'
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, 'null', 'application/json')
self.assertIsNone(response)
def test_attr_int(self):
"""
Test deserializing an object with Int attributes.
"""
response_data = mock.create_autospec(Response)
response_data.status_code = 200
response_data.headers = {
'client-request-id':"123",
'etag':456.3,
"content-type": "application/json; charset=utf8"
}
response_data.text = ''
message = {'AttrB':'1234'}
response_data.text = json.dumps(message)
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps(message), 'application/json')
self.assertTrue(hasattr(response, 'attr_b'))
self.assertEqual(response.attr_b, int(message['AttrB']))
with self.assertRaises(DeserializationError):
response_data.text = json.dumps({'AttrB':'NotANumber'})
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'AttrB':'NotANumber'}), 'application/json')
def test_attr_str(self):
"""
Test deserializing an object with Str attributes.
"""
message = {'id':'InterestingValue'}
response_data = mock.create_autospec(Response)
response_data.status_code = 200
response_data.headers = {
'client-request-id': 'a',
'etag': 'b',
"content-type": "application/json; charset=utf8"
}
response_data.text = json.dumps(message)
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps(message), 'application/json')
self.assertTrue(hasattr(response, 'attr_a'))
self.assertEqual(response.attr_a, message['id'])
message = {'id':1234}
response_data.text = json.dumps(message)
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps(message), 'application/json')
self.assertEqual(response.attr_a, str(message['id']))
message = {'id':list()}
response_data.text = json.dumps(message)
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps(message), 'application/json')
self.assertEqual(response.attr_a, str(message['id']))
response_data.text = json.dumps({'id':None})
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'id':None}), 'application/json')
self.assertEqual(response.attr_a, None)
def test_attr_bool(self):
"""
Test deserializing an object with bool attributes.
"""
response_data = mock.create_autospec(Response)
response_data.status_code = 200
response_data.headers = {
'client-request-id': 'a',
'etag': 'b',
"content-type": "application/json; charset=utf8"
}
response_data.text = json.dumps({'Key_C':True})
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'Key_C':True}), 'application/json')
self.assertTrue(hasattr(response, 'attr_c'))
self.assertEqual(response.attr_c, True)
response_data.text = json.dumps({'Key_C':[]})
with self.assertRaises(DeserializationError):
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'Key_C':[]}), 'application/json')
response_data.text = json.dumps({'Key_C':0})
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'Key_C':0}), 'application/json')
self.assertEqual(response.attr_c, False)
response_data.text = json.dumps({'Key_C':"value"})
with self.assertRaises(DeserializationError):
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'Key_C':"value"}), 'application/json')
def test_attr_list_simple(self):
"""
Test deserializing an object with simple-typed list attributes
"""
response_data = mock.create_autospec(Response)
response_data.status_code = 200
response_data.headers = {
'client-request-id': 'a',
'etag': 'b',
"content-type": "application/json; charset=utf8"
}
response_data.text = json.dumps({'AttrD': []})
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'AttrD': []}), 'application/json')
deserialized_list = [d for d in response.attr_d]
self.assertEqual(deserialized_list, [])
message = {'AttrD': [1,2,3]}
response_data.text = json.dumps(message)
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps(message), 'application/json')
deserialized_list = [d for d in response.attr_d]
self.assertEqual(deserialized_list, message['AttrD'])
message = {'AttrD': ["1","2","3"]}
response_data.text = json.dumps(message)
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps(message), 'application/json')
deserialized_list = [d for d in response.attr_d]
self.assertEqual(deserialized_list, [int(i) for i in message['AttrD']])
response_data.text = json.dumps({'AttrD': ["test","test2","test3"]})
with self.assertRaises(DeserializationError):
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'AttrD': ["test","test2","test3"]}), 'application/json')
deserialized_list = [d for d in response.attr_d]
response_data.text = json.dumps({'AttrD': "NotAList"})
with self.assertRaises(DeserializationError):
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'AttrD': "NotAList"}), 'application/json')
deserialized_list = [d for d in response.attr_d]
self.assertListEqual(sorted(self.d("[str]", ["a", "b", "c"])), ["a", "b", "c"])
@ -1760,49 +1636,30 @@ class TestRuntimeDeserialized(unittest.TestCase):
"""
Test deserializing a list of lists
"""
response_data = mock.create_autospec(Response)
response_data.status_code = 200
response_data.headers = {
'client-request-id': 'a',
'etag': 'b',
"content-type": "application/json; charset=utf8"
}
response_data.text = json.dumps({'AttrF':[]})
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'AttrF':[]}), 'application/json')
self.assertTrue(hasattr(response, 'attr_f'))
self.assertEqual(response.attr_f, [])
response_data.text = json.dumps({'AttrF':None})
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'AttrF':None}), 'application/json')
self.assertTrue(hasattr(response, 'attr_f'))
self.assertEqual(response.attr_f, None)
response_data.text = json.dumps({})
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({}), 'application/json')
self.assertTrue(hasattr(response, 'attr_f'))
self.assertEqual(response.attr_f, None)
message = {'AttrF':[[]]}
response_data.text = json.dumps(message)
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps(message), 'application/json')
self.assertTrue(hasattr(response, 'attr_f'))
self.assertEqual(response.attr_f, message['AttrF'])
message = {'AttrF':[[1,2,3], ['a','b','c']]}
response_data.text = json.dumps(message)
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps(message), 'application/json')
self.assertTrue(hasattr(response, 'attr_f'))
self.assertEqual(response.attr_f, [[str(i) for i in k] for k in message['AttrF']])
with self.assertRaises(DeserializationError):
response_data.text = json.dumps({'AttrF':[1,2,3]})
response = self.d(self.TestObj, response_data)
response = self.d(self.TestObj, json.dumps({'AttrF':[1,2,3]}), 'application/json')
def test_attr_list_complex(self):
"""
@ -1816,17 +1673,8 @@ class TestRuntimeDeserialized(unittest.TestCase):
_attribute_map = {'attr_a': {'key':'id', 'type':'[ListObj]'}}
response_data = mock.create_autospec(Response)
response_data.status_code = 200
response_data.headers = {
'client-request-id': 'a',
'etag': 'b',
"content-type": "application/json; charset=utf8"
}
response_data.text = json.dumps({"id":[{"ABC": "123"}]})
d = Deserializer({'ListObj':ListObj})
response = d(CmplxTestObj, response_data)
response = d(CmplxTestObj, json.dumps({"id":[{"ABC": "123"}]}), 'application/json')
deserialized_list = list(response.attr_a)
self.assertIsInstance(deserialized_list[0], ListObj)

Просмотреть файл

@ -0,0 +1,164 @@
#--------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#--------------------------------------------------------------------------
try:
from unittest import mock
except ImportError:
import mock
import requests
import pytest
from msrest.exceptions import DeserializationError
from msrest.universal_http import (
ClientRequest,
ClientResponse,
HTTPClientResponse,
)
from msrest.universal_http.requests import RequestsClientResponse
from msrest.pipeline import (
Response,
Request
)
from msrest.pipeline.universal import (
HTTPLogger,
RawDeserializer,
UserAgentPolicy
)
def test_user_agent():
with mock.patch.dict('os.environ', {'AZURE_HTTP_USER_AGENT': "mytools"}):
policy = UserAgentPolicy()
assert policy.user_agent.endswith("mytools")
request = ClientRequest('GET', 'http://127.0.0.1/')
policy.on_request(Request(request))
assert request.headers["user-agent"].endswith("mytools")
@mock.patch('msrest.http_logger._LOGGER')
def test_no_log(mock_http_logger):
universal_request = ClientRequest('GET', 'http://127.0.0.1/')
request = Request(universal_request)
http_logger = HTTPLogger()
response = Response(request, ClientResponse(universal_request, None))
# By default, no log handler for HTTP
http_logger.on_request(request)
mock_http_logger.debug.assert_not_called()
http_logger.on_response(request, response)
mock_http_logger.debug.assert_not_called()
mock_http_logger.reset_mock()
# I can enable it per request
http_logger.on_request(request, **{"enable_http_logger": True})
assert mock_http_logger.debug.call_count >= 1
http_logger.on_response(request, response, **{"enable_http_logger": True})
assert mock_http_logger.debug.call_count >= 1
mock_http_logger.reset_mock()
# I can enable it per request (bool value should be honored)
http_logger.on_request(request, **{"enable_http_logger": False})
mock_http_logger.debug.assert_not_called()
http_logger.on_response(request, response, **{"enable_http_logger": False})
mock_http_logger.debug.assert_not_called()
mock_http_logger.reset_mock()
# I can enable it globally
http_logger.enable_http_logger = True
http_logger.on_request(request)
assert mock_http_logger.debug.call_count >= 1
http_logger.on_response(request, response)
assert mock_http_logger.debug.call_count >= 1
mock_http_logger.reset_mock()
# I can enable it globally and override it locally
http_logger.enable_http_logger = True
http_logger.on_request(request, **{"enable_http_logger": False})
mock_http_logger.debug.assert_not_called()
http_logger.on_response(request, response, **{"enable_http_logger": False})
mock_http_logger.debug.assert_not_called()
mock_http_logger.reset_mock()
def test_raw_deserializer():
raw_deserializer = RawDeserializer()
def build_response(body, content_type=None):
class MockResponse(HTTPClientResponse):
def __init__(self, body, content_type):
super(MockResponse, self).__init__(None, None)
self._body = body
if content_type:
self.headers['content-type'] = content_type
def body(self):
return self._body
return Response(None, MockResponse(body, content_type))
response = build_response(b"<groot/>", content_type="application/xml")
raw_deserializer.on_response(None, response, stream=False)
result = response.context["deserialized_data"]
assert result.tag == "groot"
# Catch some weird situation where content_type is XML, but content is JSON
response = build_response(b'{"ugly": true}', content_type="application/xml")
raw_deserializer.on_response(None, response, stream=False)
result = response.context["deserialized_data"]
assert result["ugly"] is True
# Be sure I catch the correct exception if it's neither XML nor JSON
with pytest.raises(DeserializationError):
response = build_response(b'gibberish', content_type="application/xml")
raw_deserializer.on_response(None, response, stream=False)
with pytest.raises(DeserializationError):
response = build_response(b'{{gibberish}}', content_type="application/xml")
raw_deserializer.on_response(None, response, stream=False)
# Simple JSON
response = build_response(b'{"success": true}', content_type="application/json")
raw_deserializer.on_response(None, response, stream=False)
result = response.context["deserialized_data"]
assert result["success"] is True
# For compat, if no content-type, decode JSON
response = build_response(b'"data"')
raw_deserializer.on_response(None, response, stream=False)
result = response.context["deserialized_data"]
assert result == "data"
# Try with a mock of requests
req_response = requests.Response()
req_response.headers["content-type"] = "application/json"
req_response._content = b'{"success": true}'
req_response._content_consumed = True
response = Response(None, RequestsClientResponse(None, req_response))
raw_deserializer.on_response(None, response, stream=False)
result = response.context["deserialized_data"]
assert result["success"] is True

Просмотреть файл

@ -26,8 +26,6 @@
import sys
import xml.etree.ElementTree as ET
import requests
import pytest
from msrest.serialization import Serializer, Deserializer, Model, xml_key_extractor
@ -140,27 +138,6 @@ class TestXmlDeserialization:
assert child.tag == "Age"
assert child.text == "37"
def test_object_from_requests(self):
basic_xml = b"""<?xml version="1.0"?>
<Data country="france">
<Age>37</Age>
</Data>"""
response = requests.Response()
response.headers["content-type"] = "application/xml; charset=utf-8"
response._content = basic_xml
response._content_consumed = True
s = Deserializer()
result = s('object', response)
# Should be a XML tree
assert result.tag == "Data"
assert result.get("country") == "france"
for child in result:
assert child.tag == "Age"
assert child.text == "37"
def test_basic_empty(self):
"""Test an basic XML with an empty node."""
basic_xml = """<?xml version="1.0"?>

Просмотреть файл

@ -13,4 +13,4 @@ commands=
pytest --cov=msrest tests/
autorest: pytest --cov=msrest --cov-append autorest.python/test/vanilla/
coverage report --fail-under=40
coverage xml
coverage xml --ignore-errors # At this point, don't fail for "async" keyword in 2.7/3.4