Pipeline (#106)
* Pipeline renaming * Work in progress * Work in progress 2 * Move formData to ClientRequest * Pure ClientRequest with no requests * Add kwargs for send * Pipeline update * First pass on tests * Some typehints * Fixing all tests * Full Pipeline mypy * Py3 mock compat * Pipeline and stream download * First pass on Autorest testserver * While making coverage report, don't cry for async * Pylint * Add absolute_import for Py2.7 * Fix ABC for 2.7 * Add absolute_import to another file for 2.7 * Pipeline is a context manager * Some async ABC * Move mypy to 3.6 * Fix empty policies * aiohttp proof of concepts * Simplify ClientResponse * Improve response handling + async fixes * Fix mypy * Create a basic HTTPSender * async dependencies * Fix Pipfile for asyncio * Py3.5 compat * Add universal to async tests * Basic requests as asyncio impl * Remove check_redirect from configuration * Improve default pipeline * Make pipeline a public attribute * Split configuration for pipeline * Refactor config * Restore exception * mypy happiness * Split requests configuration * Simplify keep_alive behavior * Default parameter * Rename to on_request/on_response after feedback * Multi-thread compatible HTTP requests sender * Squashed commit of the following: commit3246847c2f
Merge:18cb696
388e8d0
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Thu Jun 14 15:26:57 2018 -0700 Merge remote-tracking branch 'origin/master' into async2 commit18cb696109
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Wed May 23 11:30:00 2018 -0700 MyPy happiness commitbd7123396b
Merge:a997e97
3a8b79d
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Wed May 23 11:23:44 2018 -0700 Merge branch 'master' into async2 commita997e97cd9
Merge:4130eca
2b7d778
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Wed May 9 13:54:11 2018 -0700 Merge remote-tracking branch 'origin/master' into async2 commit4130eca92a
Merge:8ffedd8
9d81113
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Fri Apr 20 10:38:45 2018 -0700 Merge remote-tracking branch 'origin/master' into async2 commit8ffedd8a3a
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Mar 20 15:36:40 2018 -0700 Refactor a little async stream download commitbbf1259ca8
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Fri Mar 16 17:20:07 2018 -0700 Add stream upload support commit2d260036f6
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Feb 27 16:25:33 2018 -0800 Fix incorrect request call commit6b55d4f633
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Wed Jan 17 13:39:18 2018 -0800 Add status/finished to async poller commit02c333eb13
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jan 16 15:50:06 2018 -0800 Port stream to async implementation commitb3f0ac7d29
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jan 16 15:32:23 2018 -0800 Add AsyncPoller commit3e9e17883e
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Thu Dec 7 16:13:06 2017 -0800 Sync ServiceClientAsync with latest fixes commit5483e289b5
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Thu Jul 20 11:27:05 2017 -0700 Address feedback from @brettcannon on async commitc99f4b71a7
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jul 18 13:14:27 2017 -0700 Robust coverage xml report commite0c6d3e42b
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jul 18 12:23:12 2017 -0700 Rename SC mixin commit8e029ff0be
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jul 18 12:12:12 2017 -0700 Add async_get to paging commitf3dfaf6526
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jul 18 12:06:20 2017 -0700 Rename paging mixin commit2f7142d211
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jul 18 12:05:34 2017 -0700 async_get_next in paging commit9e821009c4
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jul 18 11:20:24 2017 -0700 async send form data commit17045f776e
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jul 18 11:18:46 2017 -0700 Add async client mixin commit3294115452
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jul 18 11:08:37 2017 -0700 Fix Py3.5 async tests commit615f672aec
Author: Laurent Mazuel <laurent.mazuel@gmail.com> Date: Tue Jul 18 10:31:54 2017 -0700 Async paging with mixin * Remove useless tox line * Add credentials to async requests * Add trio support * Revamp async stream download * Add pipeline Response wrapper * Introduce a raw deserializer as a policy * SansIO on_exception * Fix deserialization tests * Implement #116 - Env variable for UserAgent * Revamp streaming * Logger should not log streamable response * Put pipeline in config * Mypy fixes * Update mypy 0.620 * Fix trio dep * Create universal HTTP package * Pipeline as a universal HTTP implementation, no mypy * Pipeline as a universal HTTP implementation, with mypy * Backward compatible ServiceClient * Doc update * Make config optional in requests engine * Fix types for MyPY 0.630 * Credentials can be directly a Policy * msrest[async] * 0.6.0rc1 first ChangeLog * Don't coverage report the TYPE_CHECKING import * Fix dev install on 3.5 and more * 0.6.0rc1 * Pre-load aiohttp body * sync ServiceClient returns requersts.Response * Fix import issue * Remove concept of requests_kwargs * Fix hooks parameter
This commit is contained in:
Родитель
50c5546691
Коммит
3653d29fc4
|
@ -0,0 +1,4 @@
|
|||
[report]
|
||||
exclude_lines =
|
||||
pragma: no cover
|
||||
if TYPE_CHECKING:
|
|
@ -20,11 +20,11 @@ _autorest_install: &_autorest_install
|
|||
jobs:
|
||||
include:
|
||||
- stage: MyPy
|
||||
python: 3.5
|
||||
python: 3.6
|
||||
install:
|
||||
- pip install mypy
|
||||
script:
|
||||
- mypy msrest --ignore-missing-imports
|
||||
- mypy msrest
|
||||
- stage: Test
|
||||
python: 2.7
|
||||
env: TOXENV=py27
|
||||
|
|
4
Pipfile
4
Pipfile
|
@ -4,12 +4,14 @@ verify_ssl = true
|
|||
name = "pypi"
|
||||
|
||||
[packages]
|
||||
"e1839a8" = {path = ".", editable = true}
|
||||
"e1839a8" = {path = ".", extras = ["async"], editable = true}
|
||||
|
||||
[dev-packages]
|
||||
pytest = "*"
|
||||
pytest-cov = "*"
|
||||
pytest-asyncio = {version = "*", markers="python_version >= '3.5'"}
|
||||
httpretty = ">=0.8.10"
|
||||
mock = {version = "*", markers="python_version <= '2.7'"}
|
||||
mypy = {version = "==0.630", markers="python_version > '2.7'"}
|
||||
pylint = "*"
|
||||
trio = {version = "*", markers="python_version >= '3.5'"}
|
||||
|
|
35
README.rst
35
README.rst
|
@ -20,6 +20,41 @@ To install:
|
|||
Release History
|
||||
---------------
|
||||
|
||||
2018-XX-XX Version 0.6.0rc1
|
||||
+++++++++++++++++++++++++++
|
||||
|
||||
**Features**
|
||||
|
||||
- The environment variable AZURE_HTTP_USER_AGENT, if present, is now injected part of the UserAgent
|
||||
- New msrest.universal_http module. Provide tools to generic HTTP management (sync/async, requests/aiohttp, etc.)
|
||||
- New **preview** msrest.pipeline implementation:
|
||||
|
||||
- A Pipeline is an ordered list of Policies than can process an HTTP request and response in a generic way.
|
||||
- More details in the wiki page about Pipeline: https://github.com/Azure/msrest-for-python/wiki/msrest-0.6.0---Pipeline
|
||||
|
||||
- Adding new attribute to Configuration instance:
|
||||
|
||||
- http_logger_policy - Policy to handle HTTP logging
|
||||
- user_agent_policy - Policy to handle HTTP logging
|
||||
- pipeline - The current pipeline used by the SDK client
|
||||
- async_pipeline - The current async pipeline used by the SDK client
|
||||
|
||||
- Installing "msrest[async]" now install the **experimental** async support
|
||||
|
||||
**Breaking changes**
|
||||
|
||||
- The HTTPDriver API introduced in 0.5.0 has been replaced by Pipeline.
|
||||
|
||||
- The following classes have been moved from "msrest.pipeline" to "msrest.universal_http":
|
||||
|
||||
- ClientRedirectPolicy
|
||||
- ClientProxies
|
||||
- ClientConnection
|
||||
|
||||
- The following classes have been moved from "msrest.pipeline" to "msrest.universal_http.requests":
|
||||
|
||||
- ClientRetryPolicy
|
||||
|
||||
2018-09-04 Version 0.5.5
|
||||
++++++++++++++++++++++++
|
||||
|
||||
|
|
|
@ -24,10 +24,10 @@
|
|||
#
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from .version import msrest_version
|
||||
from .configuration import Configuration
|
||||
from .service_client import ServiceClient, SDKClient
|
||||
from .serialization import Serializer, Deserializer
|
||||
from .version import msrest_version
|
||||
|
||||
__all__ = [
|
||||
"ServiceClient",
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
|
||||
from typing import Any, Dict, List, Union, TYPE_CHECKING
|
||||
|
||||
from .universal_http import ClientRequest
|
||||
from .universal_http.async_requests import AsyncRequestsHTTPSender
|
||||
from .pipeline import Request, AsyncPipeline, AsyncHTTPPolicy, SansIOHTTPPolicy
|
||||
from .pipeline.async_requests import (
|
||||
AsyncPipelineRequestsHTTPSender,
|
||||
AsyncRequestsCredentialsPolicy
|
||||
)
|
||||
from .pipeline.universal import (
|
||||
HTTPLogger,
|
||||
RawDeserializer,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration import Configuration # pylint: disable=unused-import
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncSDKClientMixin:
|
||||
"""The base class of all generated SDK client.
|
||||
"""
|
||||
async def __aenter__(self):
|
||||
await self._client.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_details):
|
||||
await self._client.__aexit__(*exc_details)
|
||||
|
||||
|
||||
class AsyncServiceClientMixin:
|
||||
|
||||
def __init__(self, creds: Any, config: 'Configuration') -> None:
|
||||
# Don't do super, since I know it will be "object"
|
||||
# super(AsyncServiceClientMixin, self).__init__(creds, config)
|
||||
|
||||
# "async_pipeline" be should accessible from "config"
|
||||
# In legacy mode this is weird, this config is a parameter of "pipeline"
|
||||
# Should be revamp one day.
|
||||
self.config.async_pipeline = self._create_default_async_pipeline() # type: ignore
|
||||
|
||||
def _create_default_async_pipeline(self):
|
||||
|
||||
policies = [
|
||||
self.config.user_agent_policy, # UserAgent policy
|
||||
RawDeserializer(), # Deserialize the raw bytes
|
||||
self.config.http_logger_policy # HTTP request/response log
|
||||
] # type: List[Union[AsyncHTTPPolicy, SansIOHTTPPolicy]]
|
||||
if self._creds:
|
||||
if isinstance(self._creds, (AsyncHTTPPolicy, SansIOHTTPPolicy)):
|
||||
policies.insert(1, self._creds)
|
||||
else:
|
||||
# Assume this is the old credentials class, and then requests. Wrap it.
|
||||
policies.insert(1, AsyncRequestsCredentialsPolicy(self._creds))
|
||||
|
||||
return AsyncPipeline(
|
||||
policies,
|
||||
AsyncPipelineRequestsHTTPSender(
|
||||
AsyncRequestsHTTPSender(self.config) # Send HTTP request using requests
|
||||
)
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.config.async_pipeline.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_details):
|
||||
await self.config.async_pipeline.__aexit__(*exc_details)
|
||||
|
||||
async def async_send(self, request, **kwargs):
|
||||
"""Prepare and send request object according to configuration.
|
||||
|
||||
:param ClientRequest request: The request object to be sent.
|
||||
:param dict headers: Any headers to add to the request.
|
||||
:param content: Any body data to add to the request.
|
||||
:param config: Any specific config overrides
|
||||
"""
|
||||
kwargs.setdefault('stream', True)
|
||||
# In the current backward compatible implementation, return the HTTP response
|
||||
# and plug context inside. Could be remove if we modify Autorest,
|
||||
# but we still need it to be backward compatible
|
||||
pipeline_response = await self.config.async_pipeline.run(request, **kwargs)
|
||||
response = pipeline_response.http_response
|
||||
response.context = pipeline_response.context
|
||||
return response
|
||||
|
||||
def stream_download_async(self, response, user_callback):
|
||||
"""Async Generator for streaming request body data.
|
||||
|
||||
:param response: The initial response
|
||||
:param user_callback: Custom callback for monitoring progress.
|
||||
"""
|
||||
block = self.config.connection.data_block_size
|
||||
return response.stream_download(block, user_callback)
|
|
@ -0,0 +1,74 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
from collections.abc import AsyncIterator
|
||||
import logging
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class AsyncPagedMixin(AsyncIterator):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Bring async to Paging.
|
||||
|
||||
"async_command" is mandatory keyword argument for this mixin to work.
|
||||
"""
|
||||
self._async_get_next = kwargs.get("async_command")
|
||||
if not self._async_get_next:
|
||||
_LOGGER.warning("Paging async iterator protocol is not available for %s",
|
||||
self.__class__.__name__)
|
||||
|
||||
async def async_get(self, url):
|
||||
"""Get an arbitrary page.
|
||||
|
||||
This resets the iterator and then fully consumes it to return the
|
||||
specific page **only**.
|
||||
|
||||
:param str url: URL to arbitrary page results.
|
||||
"""
|
||||
self.reset()
|
||||
self.next_link = url
|
||||
return await self.async_advance_page()
|
||||
|
||||
async def async_advance_page(self):
|
||||
if self.next_link is None:
|
||||
raise StopAsyncIteration("End of paging")
|
||||
self._current_page_iter_index = 0
|
||||
self._response = await self._async_get_next(self.next_link)
|
||||
self._derserializer(self, self._response)
|
||||
return self.current_page
|
||||
|
||||
async def __anext__(self):
|
||||
"""Iterate through responses."""
|
||||
# Storing the list iterator might work out better, but there's no
|
||||
# guarantee that some code won't replace the list entirely with a copy,
|
||||
# invalidating an list iterator that might be saved between iterations.
|
||||
if self.current_page and self._current_page_iter_index < len(self.current_page):
|
||||
response = self.current_page[self._current_page_iter_index]
|
||||
self._current_page_iter_index += 1
|
||||
return response
|
||||
else:
|
||||
await self.async_advance_page()
|
||||
return await self.__anext__()
|
|
@ -31,36 +31,24 @@ try:
|
|||
except ImportError:
|
||||
import ConfigParser as configparser # type: ignore
|
||||
from ConfigParser import NoOptionError # type: ignore
|
||||
import platform
|
||||
|
||||
from typing import Dict, List, Any, Callable
|
||||
|
||||
import requests
|
||||
from typing import TYPE_CHECKING, Optional, Dict, List, Any, Callable # pylint: disable=unused-import
|
||||
|
||||
from .exceptions import raise_with_traceback
|
||||
from .pipeline import (
|
||||
ClientRetryPolicy,
|
||||
ClientRedirectPolicy,
|
||||
ClientProxies,
|
||||
ClientConnection)
|
||||
from .version import msrest_version
|
||||
|
||||
from .pipeline import Pipeline
|
||||
from .universal_http.requests import (
|
||||
RequestHTTPSenderConfiguration
|
||||
)
|
||||
from .pipeline.universal import (
|
||||
UserAgentPolicy,
|
||||
HTTPLogger,
|
||||
)
|
||||
|
||||
def default_session_configuration_callback(session, global_config, local_config, **kwargs):
|
||||
# type: (requests.Session, Configuration, Dict[str,str], str) -> Dict[str, str]
|
||||
"""Configuration callback if you need to change default session configuration.
|
||||
if TYPE_CHECKING:
|
||||
from .pipeline import AsyncPipeline
|
||||
|
||||
:param requests.Session session: The session.
|
||||
:param Configuration global_config: The global configuration.
|
||||
:param dict[str,str] local_config: The on-the-fly configuration passed on the call.
|
||||
:param dict[str,str] kwargs: The current computed values for session.request method.
|
||||
:return: Must return kwargs, to be passed to session.request. If None is return, initial kwargs will be used.
|
||||
:rtype: dict[str,str]
|
||||
"""
|
||||
return kwargs
|
||||
|
||||
|
||||
class Configuration(object):
|
||||
class Configuration(RequestHTTPSenderConfiguration):
|
||||
"""Client configuration.
|
||||
|
||||
:param str baseurl: REST API base URL.
|
||||
|
@ -68,48 +56,28 @@ class Configuration(object):
|
|||
"""
|
||||
|
||||
def __init__(self, base_url, filepath=None):
|
||||
# type: (str, str) -> None
|
||||
# type: (str, Optional[str]) -> None
|
||||
|
||||
super(Configuration, self).__init__(filepath)
|
||||
# Service
|
||||
self.base_url = base_url
|
||||
|
||||
# Communication configuration
|
||||
self.connection = ClientConnection()
|
||||
# User-Agent as a policy
|
||||
self.user_agent_policy = UserAgentPolicy()
|
||||
|
||||
# Headers (sent with every requests)
|
||||
self.headers = {} # type: Dict[str, str]
|
||||
# HTTP logger policy
|
||||
self.http_logger_policy = HTTPLogger()
|
||||
|
||||
# ProxyConfiguration
|
||||
self.proxies = ClientProxies()
|
||||
# The sync pipeline (will be replaced by the SDK default one, this instance if just for mypy)
|
||||
self.pipeline = Pipeline() # type: Pipeline
|
||||
|
||||
# Retry configuration
|
||||
self.retry_policy = ClientRetryPolicy()
|
||||
|
||||
# Redirect configuration
|
||||
self.redirect_policy = ClientRedirectPolicy()
|
||||
|
||||
# User-Agent Header
|
||||
self._user_agent = "python/{} ({}) requests/{} msrest/{}".format(
|
||||
platform.python_version(),
|
||||
platform.platform(),
|
||||
requests.__version__,
|
||||
msrest_version)
|
||||
|
||||
# Should we log HTTP requests/response?
|
||||
self.enable_http_logger = False
|
||||
|
||||
# Requests hooks. Must respect requests hook callback signature
|
||||
# Note that we will inject the following parameters:
|
||||
# - kwargs['msrest']['session'] with the current session
|
||||
self.hooks = [] # type: List[Callable[[requests.Response, str, str], None]]
|
||||
|
||||
self.session_configuration_callback = default_session_configuration_callback
|
||||
# The async pipeline
|
||||
# This is actual optional, since on 2.7 this will be None
|
||||
self.async_pipeline = None # type: Optional[AsyncPipeline]
|
||||
|
||||
# If set to True, ServiceClient will own the sessionn
|
||||
self.keep_alive = False
|
||||
|
||||
self._config = configparser.ConfigParser()
|
||||
self._config.optionxform = str # type: ignore
|
||||
|
||||
if filepath:
|
||||
self.load(filepath)
|
||||
|
||||
|
@ -117,7 +85,7 @@ class Configuration(object):
|
|||
def user_agent(self):
|
||||
# type: () -> str
|
||||
"""The current user agent value."""
|
||||
return self._user_agent
|
||||
return self.user_agent_policy.user_agent
|
||||
|
||||
def add_user_agent(self, value):
|
||||
# type: (str) -> None
|
||||
|
@ -125,94 +93,12 @@ class Configuration(object):
|
|||
|
||||
:param str value: value to add to user agent.
|
||||
"""
|
||||
self._user_agent = "{} {}".format(self._user_agent, value)
|
||||
self.user_agent_policy.add_user_agent(value)
|
||||
|
||||
def _clear_config(self):
|
||||
# type: () -> None
|
||||
"""Clearout config object in memory."""
|
||||
for section in self._config.sections():
|
||||
self._config.remove_section(section)
|
||||
@property
|
||||
def enable_http_logger(self):
|
||||
return self.http_logger_policy.enable_http_logger
|
||||
|
||||
def save(self, filepath):
|
||||
# type: (str) -> None
|
||||
"""Save current configuration to file.
|
||||
|
||||
:param str filepath: Path to file where settings will be saved.
|
||||
:raises: ValueError if supplied filepath cannot be written to.
|
||||
"""
|
||||
sections = [
|
||||
"Connection",
|
||||
"Proxies",
|
||||
"RetryPolicy",
|
||||
"RedirectPolicy"]
|
||||
for section in sections:
|
||||
self._config.add_section(section)
|
||||
|
||||
self._config.set("Connection", "base_url", self.base_url)
|
||||
self._config.set("Connection", "timeout", self.connection.timeout)
|
||||
self._config.set("Connection", "verify", self.connection.verify)
|
||||
self._config.set("Connection", "cert", self.connection.cert)
|
||||
|
||||
self._config.set("Proxies", "proxies", self.proxies.proxies)
|
||||
self._config.set("Proxies", "env_settings",
|
||||
self.proxies.use_env_settings)
|
||||
|
||||
self._config.set("RetryPolicy", "retries", str(self.retry_policy.retries))
|
||||
self._config.set("RetryPolicy", "backoff_factor",
|
||||
str(self.retry_policy.backoff_factor))
|
||||
self._config.set("RetryPolicy", "max_backoff",
|
||||
str(self.retry_policy.max_backoff))
|
||||
|
||||
self._config.set("RedirectPolicy", "allow", self.redirect_policy.allow)
|
||||
self._config.set("RedirectPolicy", "max_redirects",
|
||||
self.redirect_policy.max_redirects)
|
||||
try:
|
||||
with open(filepath, 'w') as configfile:
|
||||
self._config.write(configfile)
|
||||
except (KeyError, EnvironmentError):
|
||||
error = "Supplied config filepath invalid."
|
||||
raise_with_traceback(ValueError, error)
|
||||
finally:
|
||||
self._clear_config()
|
||||
|
||||
def load(self, filepath):
|
||||
# type: (str) -> None
|
||||
"""Load configuration from existing file.
|
||||
|
||||
:param str filepath: Path to existing config file.
|
||||
:raises: ValueError if supplied config file is invalid.
|
||||
"""
|
||||
try:
|
||||
self._config.read(filepath)
|
||||
|
||||
self.base_url = \
|
||||
self._config.get("Connection", "base_url")
|
||||
self.connection.timeout = \
|
||||
self._config.getint("Connection", "timeout")
|
||||
self.connection.verify = \
|
||||
self._config.getboolean("Connection", "verify")
|
||||
self.connection.cert = \
|
||||
self._config.get("Connection", "cert")
|
||||
|
||||
self.proxies.proxies = \
|
||||
eval(self._config.get("Proxies", "proxies"))
|
||||
self.proxies.use_env_settings = \
|
||||
self._config.getboolean("Proxies", "env_settings")
|
||||
|
||||
self.retry_policy.retries = \
|
||||
self._config.getint("RetryPolicy", "retries")
|
||||
self.retry_policy.backoff_factor = \
|
||||
self._config.getfloat("RetryPolicy", "backoff_factor")
|
||||
self.retry_policy.max_backoff = \
|
||||
self._config.getint("RetryPolicy", "max_backoff")
|
||||
|
||||
self.redirect_policy.allow = \
|
||||
self._config.getboolean("RedirectPolicy", "allow")
|
||||
self.redirect_policy.max_redirects = \
|
||||
self._config.getint("RedirectPolicy", "max_redirects")
|
||||
|
||||
except (ValueError, EnvironmentError, NoOptionError):
|
||||
error = "Supplied config file incompatible."
|
||||
raise_with_traceback(ValueError, error)
|
||||
finally:
|
||||
self._clear_config()
|
||||
@enable_http_logger.setter
|
||||
def enable_http_logger(self, value):
|
||||
self.http_logger_policy.enable_http_logger = value
|
||||
|
|
|
@ -37,7 +37,7 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def raise_with_traceback(exception, message="", *args, **kwargs):
|
||||
# type: (Callable, str, str, str) -> None
|
||||
# type: (Callable, str, Any, Any) -> None
|
||||
"""Raise exception with a specified traceback.
|
||||
|
||||
This MUST be called inside a "except" clause.
|
||||
|
@ -140,10 +140,13 @@ class HttpOperationError(ClientException):
|
|||
|
||||
def __init__(self, deserialize, response,
|
||||
resp_type=None, *args, **kwargs):
|
||||
# type: (Deserializer, requests.Response, Optional[str], str, str) -> None
|
||||
# type: (Deserializer, Any, Optional[str], str, str) -> None
|
||||
self.error = None
|
||||
self.message = self._DEFAULT_MESSAGE
|
||||
self.response = response
|
||||
if hasattr(response, 'internal_response'):
|
||||
self.response = response.internal_response
|
||||
else:
|
||||
self.response = response
|
||||
try:
|
||||
if resp_type:
|
||||
self.error = deserialize(resp_type, response)
|
||||
|
@ -168,9 +171,11 @@ class HttpOperationError(ClientException):
|
|||
try:
|
||||
response.raise_for_status()
|
||||
# Two possible raises here:
|
||||
# - Attribute error if response is not requests.RequestException. Do not catch.
|
||||
# - requests.RequestException. Catch base class IOError to avoid explicit import of requests here.
|
||||
except IOError as err:
|
||||
# - Attribute error if response is not ClientResponse. Do not catch.
|
||||
# - Any internal exception, take it.
|
||||
except AttributeError:
|
||||
raise
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
if not self.error:
|
||||
self.error = err
|
||||
|
||||
|
|
|
@ -28,16 +28,16 @@ import logging
|
|||
import re
|
||||
import types
|
||||
|
||||
from typing import Any, Union, Optional, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING # pylint: disable=unused-import
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import requests
|
||||
from .universal_http import ClientRequest, ClientResponse # pylint: disable=unused-import
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def log_request(_, request, *args, **kwargs):
|
||||
# type: (Any, requests.PreparedRequest, str, str) -> None
|
||||
def log_request(_, request, *_args, **_kwargs):
|
||||
# type: (Any, ClientRequest, str, str) -> None
|
||||
"""Log a client request.
|
||||
|
||||
:param _: Unused in current version (will be None)
|
||||
|
@ -61,12 +61,12 @@ def log_request(_, request, *args, **kwargs):
|
|||
_LOGGER.debug("File upload")
|
||||
else:
|
||||
_LOGGER.debug(str(request.body))
|
||||
except Exception as err:
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.debug("Failed to log request: %r", err)
|
||||
|
||||
|
||||
def log_response(_, request, response, *args, **kwargs):
|
||||
# type: (Any, requests.PreparedRequest, requests.Response, str, str) -> Optional[requests.Response]
|
||||
def log_response(_, _request, response, *_args, **kwargs):
|
||||
# type: (Any, ClientRequest, ClientResponse, str, Any) -> Optional[ClientResponse]
|
||||
"""Log a server response.
|
||||
|
||||
:param _: Unused in current version (will be None)
|
||||
|
@ -89,14 +89,17 @@ def log_response(_, request, response, *args, **kwargs):
|
|||
|
||||
if header and pattern.match(header):
|
||||
filename = header.partition('=')[2]
|
||||
_LOGGER.debug("File attachments: " + filename)
|
||||
_LOGGER.debug("File attachments: %s", filename)
|
||||
elif response.headers.get("content-type", "").endswith("octet-stream"):
|
||||
_LOGGER.debug("Body contains binary data.")
|
||||
elif response.headers.get("content-type", "").startswith("image"):
|
||||
_LOGGER.debug("Body contains image data.")
|
||||
else:
|
||||
_LOGGER.debug(str(response.content))
|
||||
if kwargs.get('stream', False):
|
||||
_LOGGER.debug("Body is streamable")
|
||||
else:
|
||||
_LOGGER.debug(response.text())
|
||||
return response
|
||||
except Exception as err:
|
||||
_LOGGER.debug("Failed to log response: " + repr(err))
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.debug("Failed to log response: %s", repr(err))
|
||||
return response
|
||||
|
|
|
@ -23,39 +23,51 @@
|
|||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
import sys
|
||||
try:
|
||||
from collections.abc import Iterator
|
||||
xrange = range
|
||||
except ImportError:
|
||||
from collections import Iterator
|
||||
|
||||
from typing import Dict, Any, List, Callable, Optional, TYPE_CHECKING
|
||||
from typing import Dict, Any, List, Callable, Optional, TYPE_CHECKING # pylint: disable=unused-import
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import requests
|
||||
|
||||
from .serialization import Deserializer, Model
|
||||
from .serialization import Deserializer
|
||||
from .pipeline import ClientRawResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .universal_http import ClientResponse # pylint: disable=unused-import
|
||||
from .serialization import Model # pylint: disable=unused-import
|
||||
|
||||
class Paged(Iterator):
|
||||
if sys.version_info >= (3, 5, 2):
|
||||
# Not executed on old Python, no syntax error
|
||||
from .async_paging import AsyncPagedMixin # type: ignore
|
||||
else:
|
||||
class AsyncPagedMixin(object): # type: ignore
|
||||
pass
|
||||
|
||||
class Paged(AsyncPagedMixin, Iterator):
|
||||
"""A container for paged REST responses.
|
||||
|
||||
:param requests.Response response: server response object.
|
||||
:param ClientResponse response: server response object.
|
||||
:param callable command: Function to retrieve the next page of items.
|
||||
:param dict classes: A dictionary of class dependencies for
|
||||
deserialization.
|
||||
:param dict raw_headers: A dict of raw headers to add if "raw" is called
|
||||
"""
|
||||
_validation = {} # type: Dict[str, Dict[str, Any]]
|
||||
_attribute_map = {} # type: Dict[str, Dict[str, Any]]
|
||||
|
||||
def __init__(self, command, classes, raw_headers=None):
|
||||
# type: (Callable[[str], requests.Response], Dict[str, Model], Dict[str, str]) -> None
|
||||
def __init__(self, command, classes, raw_headers=None, **kwargs):
|
||||
# type: (Callable[[str], ClientResponse], Dict[str, Model], Dict[str, str], Any) -> None
|
||||
super(Paged, self).__init__(**kwargs) # type: ignore
|
||||
# Sets next_link, current_page, and _current_page_iter_index.
|
||||
self.next_link = ""
|
||||
self._current_page_iter_index = 0
|
||||
self.reset()
|
||||
self._derserializer = Deserializer(classes)
|
||||
self._get_next = command
|
||||
self._response = None # type: Optional[requests.Response]
|
||||
self._response = None # type: Optional[ClientResponse]
|
||||
self._raw_headers = raw_headers
|
||||
|
||||
def __iter__(self):
|
||||
|
|
|
@ -1,273 +0,0 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from typing import Dict, Any, Optional, Union, List, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import requests
|
||||
from urllib3 import Retry # Needs requests 2.16 at least to be safe
|
||||
|
||||
from .serialization import Deserializer, Model
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class ClientRequest(requests.Request):
|
||||
"""Wrapper for requests.Request object."""
|
||||
|
||||
def format_parameters(self, params):
|
||||
# type: (Dict[str, str]) -> None
|
||||
"""Format parameters into a valid query string.
|
||||
It's assumed all parameters have already been quoted as
|
||||
valid URL strings.
|
||||
|
||||
:param dict params: A dictionary of parameters.
|
||||
"""
|
||||
query = urlparse(self.url).query
|
||||
if query:
|
||||
self.url = self.url.partition('?')[0]
|
||||
existing_params = {
|
||||
p[0]: p[-1]
|
||||
for p in [p.partition('=') for p in query.split('&')]
|
||||
}
|
||||
params.update(existing_params)
|
||||
query_params = ["{}={}".format(k, v) for k, v in params.items()]
|
||||
query = '?' + '&'.join(query_params)
|
||||
self.url = self.url + query
|
||||
|
||||
def add_content(self, data):
|
||||
# type: (Optional[Union[Dict[str, Any], ET.Element]]) -> None
|
||||
"""Add a body to the request.
|
||||
|
||||
:param data: Request body data, can be a json serializable
|
||||
object (e.g. dictionary) or a generator (e.g. file data).
|
||||
"""
|
||||
if data is None:
|
||||
return
|
||||
|
||||
if isinstance(data, ET.Element):
|
||||
self.data = ET.tostring(data, encoding="utf8")
|
||||
self.headers['Content-Length'] = str(len(self.data))
|
||||
return
|
||||
|
||||
# By default, assume JSON
|
||||
try:
|
||||
self.data = json.dumps(data)
|
||||
self.headers['Content-Length'] = str(len(self.data))
|
||||
except TypeError:
|
||||
self.data = data
|
||||
|
||||
|
||||
class ClientRawResponse(object):
|
||||
"""Wrapper for response object.
|
||||
This allows for additional data to be gathereded from the response,
|
||||
for example deserialized headers.
|
||||
It also allows the raw response object to be passed back to the user.
|
||||
|
||||
:param output: Deserialized response object.
|
||||
:param response: Raw response object.
|
||||
"""
|
||||
|
||||
def __init__(self, output, response):
|
||||
# type: (Union[Model, List[Model]], Optional[requests.Response]) -> None
|
||||
self.response = response
|
||||
self.output = output
|
||||
self.headers = {} # type: Dict[str, Optional[Any]]
|
||||
self._deserialize = Deserializer()
|
||||
|
||||
def add_headers(self, header_dict):
|
||||
# type: (Dict[str, str]) -> None
|
||||
"""Deserialize a specific header.
|
||||
|
||||
:param dict header_dict: A dictionary containing the name of the
|
||||
header and the type to deserialize to.
|
||||
"""
|
||||
if not self.response:
|
||||
return
|
||||
for name, data_type in header_dict.items():
|
||||
value = self.response.headers.get(name)
|
||||
value = self._deserialize(data_type, value)
|
||||
self.headers[name] = value
|
||||
|
||||
|
||||
class ClientRetryPolicy(object):
|
||||
"""Retry configuration settings.
|
||||
Container for retry policy object.
|
||||
"""
|
||||
|
||||
safe_codes = [i for i in range(500) if i != 408] + [501, 505]
|
||||
|
||||
def __init__(self):
|
||||
self.policy = Retry()
|
||||
self.policy.total = 3
|
||||
self.policy.connect = 3
|
||||
self.policy.read = 3
|
||||
self.policy.backoff_factor = 0.8
|
||||
self.policy.BACKOFF_MAX = 90
|
||||
|
||||
retry_codes = [i for i in range(999) if i not in self.safe_codes]
|
||||
self.policy.status_forcelist = retry_codes
|
||||
self.policy.method_whitelist = ['HEAD', 'TRACE', 'GET', 'PUT',
|
||||
'OPTIONS', 'DELETE', 'POST', 'PATCH']
|
||||
|
||||
def __call__(self):
|
||||
# type: () -> Retry
|
||||
"""Return configuration to be applied to connection."""
|
||||
debug = ("Configuring retry: max_retries=%r, "
|
||||
"backoff_factor=%r, max_backoff=%r")
|
||||
_LOGGER.debug(
|
||||
debug, self.retries, self.backoff_factor, self.max_backoff)
|
||||
return self.policy
|
||||
|
||||
@property
|
||||
def retries(self):
|
||||
# type: () -> int
|
||||
"""Total number of allowed retries."""
|
||||
return self.policy.total
|
||||
|
||||
@retries.setter
|
||||
def retries(self, value):
|
||||
# type: (int) -> None
|
||||
self.policy.total = value
|
||||
self.policy.connect = value
|
||||
self.policy.read = value
|
||||
|
||||
@property
|
||||
def backoff_factor(self):
|
||||
# type: () -> Union[int, float]
|
||||
"""Factor by which back-off delay is incementally increased."""
|
||||
return self.policy.backoff_factor
|
||||
|
||||
@backoff_factor.setter
|
||||
def backoff_factor(self, value):
|
||||
# type: (Union[int, float]) -> None
|
||||
self.policy.backoff_factor = value
|
||||
|
||||
@property
|
||||
def max_backoff(self):
|
||||
# type: () -> int
|
||||
"""Max retry back-off delay."""
|
||||
return self.policy.BACKOFF_MAX
|
||||
|
||||
@max_backoff.setter
|
||||
def max_backoff(self, value):
|
||||
# type: (int) -> None
|
||||
self.policy.BACKOFF_MAX = value
|
||||
|
||||
|
||||
class ClientRedirectPolicy(object):
|
||||
"""Redirect configuration settings.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.allow = True
|
||||
self.max_redirects = 30
|
||||
|
||||
def __bool__(self):
|
||||
# type: () -> bool
|
||||
"""Whether redirects are allowed."""
|
||||
return self.allow
|
||||
|
||||
def __call__(self):
|
||||
# type: () -> int
|
||||
"""Return configuration to be applied to connection."""
|
||||
debug = "Configuring redirects: allow=%r, max=%r"
|
||||
_LOGGER.debug(debug, self.allow, self.max_redirects)
|
||||
return self.max_redirects
|
||||
|
||||
def check_redirect(self, resp, request):
|
||||
# type: (requests.Response, requests.PreparedRequest) -> bool
|
||||
"""Whether redirect policy should be applied based on status code."""
|
||||
if resp.status_code in (301, 302) and \
|
||||
request.method not in ['GET', 'HEAD']:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class ClientProxies(object):
|
||||
"""Proxy configuration settings.
|
||||
Proxies can also be configured using HTTP_PROXY and HTTPS_PROXY
|
||||
environment variables, in which case set use_env_settings to True.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.proxies = {}
|
||||
self.use_env_settings = True
|
||||
|
||||
def __call__(self):
|
||||
# type: () -> Dict[str, str]
|
||||
"""Return configuration to be applied to connection."""
|
||||
proxy_string = "\n".join(
|
||||
[" {}: {}".format(k, v) for k, v in self.proxies.items()])
|
||||
|
||||
_LOGGER.debug("Configuring proxies: %r", proxy_string)
|
||||
debug = "Evaluate proxies against ENV settings: %r"
|
||||
_LOGGER.debug(debug, self.use_env_settings)
|
||||
return self.proxies
|
||||
|
||||
def add(self, protocol, proxy_url):
|
||||
# type: (str, str) -> None
|
||||
"""Add proxy.
|
||||
|
||||
:param str protocol: Protocol for which proxy is to be applied. Can
|
||||
be 'http', 'https', etc. Can also include host.
|
||||
:param str proxy_url: The proxy URL. Where basic auth is required,
|
||||
use the format: http://user:password@host
|
||||
"""
|
||||
self.proxies[protocol] = proxy_url
|
||||
|
||||
|
||||
class ClientConnection(object):
|
||||
"""Request connection configuration settings.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.timeout = 100
|
||||
self.verify = True
|
||||
self.cert = None
|
||||
self.data_block_size = 4096
|
||||
|
||||
def __call__(self):
|
||||
# type: () -> Dict[str, Union[str, int]]
|
||||
"""Return configuration to be applied to connection."""
|
||||
debug = "Configuring request: timeout=%r, verify=%r, cert=%r"
|
||||
_LOGGER.debug(debug, self.timeout, self.verify, self.cert)
|
||||
return {'timeout': self.timeout,
|
||||
'verify': self.verify,
|
||||
'cert': self.cert}
|
|
@ -0,0 +1,328 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import absolute_import # we have a "requests" module that conflicts with "requests" on Py2.7
|
||||
import abc
|
||||
try:
|
||||
import configparser
|
||||
from configparser import NoOptionError
|
||||
except ImportError:
|
||||
import ConfigParser as configparser # type: ignore
|
||||
from ConfigParser import NoOptionError # type: ignore
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, Optional, Tuple, Callable, Iterator # pylint: disable=unused-import
|
||||
|
||||
HTTPResponseType = TypeVar("HTTPResponseType")
|
||||
HTTPRequestType = TypeVar("HTTPRequestType")
|
||||
|
||||
# This file is NOT using any "requests" HTTP implementation
|
||||
# However, the CaseInsensitiveDict is handy.
|
||||
# If one day we reach the point where "requests" can be skip totally,
|
||||
# might provide our own implementation
|
||||
from requests.structures import CaseInsensitiveDict
|
||||
|
||||
from ..exceptions import ClientRequestError, raise_with_traceback
|
||||
from ..universal_http import ClientResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..serialization import Model # pylint: disable=unused-import
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
ABC = abc.ABC
|
||||
except AttributeError: # Python 2.7, abc exists, but not ABC
|
||||
ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) # type: ignore
|
||||
|
||||
try:
|
||||
from contextlib import AbstractContextManager # type: ignore
|
||||
except ImportError: # Python <= 3.5
|
||||
class AbstractContextManager(object): # type: ignore
|
||||
def __enter__(self):
|
||||
"""Return `self` upon entering the runtime context."""
|
||||
return self
|
||||
|
||||
@abc.abstractmethod
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
class HTTPPolicy(ABC, Generic[HTTPRequestType, HTTPResponseType]):
|
||||
"""An http policy ABC.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.next = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def send(self, request, **kwargs):
|
||||
# type: (Request[HTTPRequestType], Any) -> Response[HTTPRequestType, HTTPResponseType]
|
||||
"""Mutate the request.
|
||||
|
||||
Context content is dependent of the HTTPSender.
|
||||
"""
|
||||
pass
|
||||
|
||||
class SansIOHTTPPolicy(Generic[HTTPRequestType, HTTPResponseType]):
|
||||
"""Represents a sans I/O policy.
|
||||
|
||||
This policy can act before the I/O, and after the I/O.
|
||||
Use this policy if the actual I/O in the middle is an implementation
|
||||
detail.
|
||||
|
||||
Context is not available, since it's implementation dependent.
|
||||
if a policy needs a context of the Sender, it can't be universal.
|
||||
|
||||
Example: setting a UserAgent does not need to be tight to
|
||||
sync or async implementation or specific HTTP lib
|
||||
"""
|
||||
def on_request(self, request, **kwargs):
|
||||
# type: (Request[HTTPRequestType], Any) -> None
|
||||
"""Is executed before sending the request to next policy.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_response(self, request, response, **kwargs):
|
||||
# type: (Request[HTTPRequestType], Response[HTTPRequestType, HTTPResponseType], Any) -> None
|
||||
"""Is executed after the request comes back from the policy.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_exception(self, request, **kwargs):
|
||||
# type: (Request[HTTPRequestType], Any) -> bool
|
||||
"""Is executed if an exception comes back fron the following
|
||||
policy.
|
||||
|
||||
Return True if the exception has been handled and should not
|
||||
be forwarded to the caller.
|
||||
|
||||
This method is executed inside the exception handler.
|
||||
To get the exception, raise and catch it:
|
||||
|
||||
try:
|
||||
raise
|
||||
except MyError:
|
||||
do_something()
|
||||
|
||||
or use
|
||||
|
||||
exc_type, exc_value, exc_traceback = sys.exc_info()
|
||||
"""
|
||||
return False
|
||||
|
||||
class _SansIOHTTPPolicyRunner(HTTPPolicy, Generic[HTTPRequestType, HTTPResponseType]):
|
||||
"""Sync implementation of the SansIO policy.
|
||||
"""
|
||||
|
||||
def __init__(self, policy):
|
||||
# type: (SansIOHTTPPolicy) -> None
|
||||
super(_SansIOHTTPPolicyRunner, self).__init__()
|
||||
self._policy = policy
|
||||
|
||||
def send(self, request, **kwargs):
|
||||
# type: (Request[HTTPRequestType], Any) -> Response[HTTPRequestType, HTTPResponseType]
|
||||
self._policy.on_request(request, **kwargs)
|
||||
try:
|
||||
response = self.next.send(request, **kwargs)
|
||||
except Exception:
|
||||
if not self._policy.on_exception(request, **kwargs):
|
||||
raise
|
||||
else:
|
||||
self._policy.on_response(request, response, **kwargs)
|
||||
return response
|
||||
|
||||
class Pipeline(AbstractContextManager, Generic[HTTPRequestType, HTTPResponseType]):
|
||||
"""A pipeline implementation.
|
||||
|
||||
This is implemented as a context manager, that will activate the context
|
||||
of the HTTP sender.
|
||||
"""
|
||||
|
||||
def __init__(self, policies=None, sender=None):
|
||||
# type: (List[Union[HTTPPolicy, SansIOHTTPPolicy]], HTTPSender) -> None
|
||||
self._impl_policies = [] # type: List[HTTPPolicy]
|
||||
if not sender:
|
||||
# Import default only if nothing is provided
|
||||
from .requests import PipelineRequestsHTTPSender
|
||||
self._sender = cast(HTTPSender, PipelineRequestsHTTPSender())
|
||||
else:
|
||||
self._sender = sender
|
||||
for policy in (policies or []):
|
||||
if isinstance(policy, SansIOHTTPPolicy):
|
||||
self._impl_policies.append(_SansIOHTTPPolicyRunner(policy))
|
||||
else:
|
||||
self._impl_policies.append(policy)
|
||||
for index in range(len(self._impl_policies)-1):
|
||||
self._impl_policies[index].next = self._impl_policies[index+1]
|
||||
if self._impl_policies:
|
||||
self._impl_policies[-1].next = self._sender
|
||||
|
||||
def __enter__(self):
|
||||
# type: () -> Pipeline
|
||||
self._sender.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_details): # pylint: disable=arguments-differ
|
||||
self._sender.__exit__(*exc_details)
|
||||
|
||||
def run(self, request, **kwargs):
|
||||
# type: (HTTPRequestType, Any) -> Response
|
||||
context = self._sender.build_context()
|
||||
pipeline_request = Request(request, context) # type: Request[HTTPRequestType]
|
||||
first_node = self._impl_policies[0] if self._impl_policies else self._sender
|
||||
return first_node.send(pipeline_request, **kwargs) # type: ignore
|
||||
|
||||
class HTTPSender(AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType]):
|
||||
"""An http sender ABC.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def send(self, request, **config):
|
||||
# type: (Request[HTTPRequestType], Any) -> Response[HTTPRequestType, HTTPResponseType]
|
||||
"""Send the request using this HTTP sender.
|
||||
"""
|
||||
pass
|
||||
|
||||
def build_context(self):
|
||||
# type: () -> Any
|
||||
"""Allow the sender to build a context that will be passed
|
||||
across the pipeline with the request.
|
||||
|
||||
Return type has no constraints. Implementation is not
|
||||
required and None by default.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
class Request(Generic[HTTPRequestType]):
|
||||
"""Represents a HTTP request in a Pipeline.
|
||||
|
||||
URL can be given without query parameters, to be added later using "format_parameters".
|
||||
|
||||
Instance can be created without data, to be added later using "add_content"
|
||||
|
||||
Instance can be created without files, to be added later using "add_formdata"
|
||||
|
||||
:param str method: HTTP method (GET, HEAD, etc.)
|
||||
:param str url: At least complete scheme/host/path
|
||||
:param dict[str,str] headers: HTTP headers
|
||||
:param files: Files list.
|
||||
:param data: Body to be sent.
|
||||
:type data: bytes or str.
|
||||
"""
|
||||
def __init__(self, http_request, context=None):
|
||||
# type: (HTTPRequestType, Optional[Any]) -> None
|
||||
self.http_request = http_request
|
||||
self.context = context
|
||||
|
||||
|
||||
class Response(Generic[HTTPRequestType, HTTPResponseType]):
|
||||
"""A pipeline response object.
|
||||
|
||||
The Response interface exposes an HTTP response object as it returns through the pipeline of Policy objects.
|
||||
This ensures that Policy objects have access to the HTTP response.
|
||||
|
||||
This also have a "context" dictionnary where policy can put additional fields.
|
||||
Policy SHOULD update the "context" dictionary with additional post-processed field if they create them.
|
||||
However, nothing prevents a policy to actually sub-class this class a return it instead of the initial instance.
|
||||
"""
|
||||
def __init__(self, request, http_response, context=None):
|
||||
# type: (Request[HTTPRequestType], HTTPResponseType, Optional[Dict[str, Any]]) -> None
|
||||
self.request = request
|
||||
self.http_response = http_response
|
||||
self.context = context or {}
|
||||
|
||||
|
||||
|
||||
# ClientRawResponse is in Pipeline for compat, but technically there is nothing Pipeline here, this is deserialization
|
||||
|
||||
class ClientRawResponse(object):
|
||||
"""Wrapper for response object.
|
||||
This allows for additional data to be gathereded from the response,
|
||||
for example deserialized headers.
|
||||
It also allows the raw response object to be passed back to the user.
|
||||
|
||||
:param output: Deserialized response object.
|
||||
:param response: Raw response object.
|
||||
"""
|
||||
|
||||
def __init__(self, output, response):
|
||||
# type: (Union[Model, List[Model]], Optional[Union[Response, ClientResponse]]) -> None
|
||||
from ..serialization import Deserializer
|
||||
|
||||
if isinstance(response, Response):
|
||||
# If pipeline response, remove that layer
|
||||
response = response.http_response
|
||||
|
||||
if isinstance(response, ClientResponse):
|
||||
# If universal driver, remove that layer
|
||||
self.response = response.internal_response
|
||||
else:
|
||||
self.response = response
|
||||
self.output = output
|
||||
self.headers = {} # type: Dict[str, Optional[Any]]
|
||||
self._deserialize = Deserializer()
|
||||
|
||||
def add_headers(self, header_dict):
|
||||
# type: (Dict[str, str]) -> None
|
||||
"""Deserialize a specific header.
|
||||
|
||||
:param dict header_dict: A dictionary containing the name of the
|
||||
header and the type to deserialize to.
|
||||
"""
|
||||
if not self.response:
|
||||
return
|
||||
for name, data_type in header_dict.items():
|
||||
value = self.response.headers.get(name)
|
||||
value = self._deserialize(data_type, value)
|
||||
self.headers[name] = value
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
'Request',
|
||||
'Response',
|
||||
'Pipeline',
|
||||
'HTTPPolicy',
|
||||
'SansIOHTTPPolicy',
|
||||
'HTTPSender',
|
||||
# backward compat
|
||||
'ClientRawResponse',
|
||||
]
|
||||
|
||||
try:
|
||||
from .async_abc import AsyncPipeline, AsyncHTTPPolicy, AsyncHTTPSender # pylint: disable=unused-import
|
||||
from .async_abc import __all__ as _async_all
|
||||
__all__ += _async_all
|
||||
except SyntaxError: # Python 2
|
||||
pass
|
|
@ -0,0 +1,62 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..universal_http.aiohttp import AioHTTPSender as _AioHTTPSenderDriver
|
||||
from . import AsyncHTTPSender, Request, Response
|
||||
|
||||
# Matching requests, because why not?
|
||||
CONTENT_CHUNK_SIZE = 10 * 1024
|
||||
|
||||
class AioHTTPSender(AsyncHTTPSender):
|
||||
"""AioHttp HTTP sender implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, driver: Optional[_AioHTTPSenderDriver] = None, *, loop=None) -> None:
|
||||
self.driver = driver or _AioHTTPSenderDriver(loop=loop)
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.driver.__aenter__()
|
||||
|
||||
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
|
||||
await self.driver.__aexit__(*exc_details)
|
||||
|
||||
def build_context(self) -> Any:
|
||||
"""Allow the sender to build a context that will be passed
|
||||
across the pipeline with the request.
|
||||
|
||||
Return type has no constraints. Implementation is not
|
||||
required and None by default.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def send(self, request: Request, **config: Any) -> Response:
|
||||
"""Send the request using this HTTP sender.
|
||||
"""
|
||||
return Response(
|
||||
request,
|
||||
await self.driver.send(request.http_request)
|
||||
)
|
|
@ -0,0 +1,165 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
import abc
|
||||
|
||||
from typing import Any, List, Union, Callable, AsyncIterator, Optional, Generic, TypeVar
|
||||
|
||||
from . import Request, Response, Pipeline, SansIOHTTPPolicy
|
||||
|
||||
|
||||
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType")
|
||||
HTTPRequestType = TypeVar("HTTPRequestType")
|
||||
|
||||
try:
|
||||
from contextlib import AbstractAsyncContextManager # type: ignore
|
||||
except ImportError: # Python <= 3.7
|
||||
class AbstractAsyncContextManager(object): # type: ignore
|
||||
async def __aenter__(self):
|
||||
"""Return `self` upon entering the runtime context."""
|
||||
return self
|
||||
|
||||
@abc.abstractmethod
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
||||
class AsyncHTTPPolicy(abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]):
|
||||
"""An http policy ABC.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
# next will be set once in the pipeline
|
||||
self.next = None # type: Optional[Union[AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType], AsyncHTTPSender[HTTPRequestType, AsyncHTTPResponseType]]]
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(self, request: Request, **kwargs: Any) -> Response[HTTPRequestType, AsyncHTTPResponseType]:
|
||||
"""Mutate the request.
|
||||
|
||||
Context content is dependent of the HTTPSender.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class _SansIOAsyncHTTPPolicyRunner(AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]):
|
||||
"""Async implementation of the SansIO policy.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: SansIOHTTPPolicy) -> None:
|
||||
super(_SansIOAsyncHTTPPolicyRunner, self).__init__()
|
||||
self._policy = policy
|
||||
|
||||
async def send(self, request: Request, **kwargs: Any) -> Response[HTTPRequestType, AsyncHTTPResponseType]:
|
||||
self._policy.on_request(request, **kwargs)
|
||||
try:
|
||||
response = await self.next.send(request, **kwargs) # type: ignore
|
||||
except Exception:
|
||||
if not self._policy.on_exception(request, **kwargs):
|
||||
raise
|
||||
else:
|
||||
self._policy.on_response(request, response, **kwargs)
|
||||
return response
|
||||
|
||||
|
||||
class AsyncHTTPSender(AbstractAsyncContextManager, abc.ABC, Generic[HTTPRequestType, AsyncHTTPResponseType]):
|
||||
"""An http sender ABC.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(self, request: Request[HTTPRequestType], **config: Any) -> Response[HTTPRequestType, AsyncHTTPResponseType]:
|
||||
"""Send the request using this HTTP sender.
|
||||
"""
|
||||
pass
|
||||
|
||||
def build_context(self) -> Any:
|
||||
"""Allow the sender to build a context that will be passed
|
||||
across the pipeline with the request.
|
||||
|
||||
Return type has no constraints. Implementation is not
|
||||
required and None by default.
|
||||
"""
|
||||
return None
|
||||
|
||||
def __enter__(self):
|
||||
raise TypeError("Use async with instead")
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# __exit__ should exist in pair with __enter__ but never executed
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
class AsyncPipeline(AbstractAsyncContextManager, Generic[HTTPRequestType, AsyncHTTPResponseType]):
|
||||
"""A pipeline implementation.
|
||||
|
||||
This is implemented as a context manager, that will activate the context
|
||||
of the HTTP sender.
|
||||
"""
|
||||
|
||||
def __init__(self, policies: List[Union[AsyncHTTPPolicy, SansIOHTTPPolicy]] = None, sender: Optional[AsyncHTTPSender[HTTPRequestType, AsyncHTTPResponseType]] = None) -> None:
|
||||
self._impl_policies = [] # type: List[AsyncHTTPPolicy[HTTPRequestType, AsyncHTTPResponseType]]
|
||||
if sender:
|
||||
self._sender = sender
|
||||
else:
|
||||
# Import default only if nothing is provided
|
||||
from .aiohttp import AioHTTPSender
|
||||
self._sender = AioHTTPSender()
|
||||
|
||||
for policy in (policies or []):
|
||||
if isinstance(policy, SansIOHTTPPolicy):
|
||||
self._impl_policies.append(_SansIOAsyncHTTPPolicyRunner(policy))
|
||||
else:
|
||||
self._impl_policies.append(policy)
|
||||
for index in range(len(self._impl_policies)-1):
|
||||
self._impl_policies[index].next = self._impl_policies[index+1]
|
||||
if self._impl_policies:
|
||||
self._impl_policies[-1].next = self._sender
|
||||
|
||||
def __enter__(self):
|
||||
raise TypeError("Use 'async with' instead")
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# __exit__ should exist in pair with __enter__ but never executed
|
||||
pass # pragma: no cover
|
||||
|
||||
async def __aenter__(self) -> 'AsyncPipeline':
|
||||
await self._sender.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
|
||||
await self._sender.__aexit__(*exc_details)
|
||||
|
||||
async def run(self, request: Request, **kwargs: Any) -> Response[HTTPRequestType, AsyncHTTPResponseType]:
|
||||
context = self._sender.build_context()
|
||||
pipeline_request = Request(request, context)
|
||||
first_node = self._impl_policies[0] if self._impl_policies else self._sender
|
||||
return await first_node.send(pipeline_request, **kwargs) # type: ignore
|
||||
|
||||
__all__ = [
|
||||
'AsyncHTTPPolicy',
|
||||
'AsyncHTTPSender',
|
||||
'AsyncPipeline',
|
||||
]
|
|
@ -0,0 +1,129 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, AsyncIterator as AsyncIteratorType
|
||||
|
||||
from oauthlib import oauth2
|
||||
import requests
|
||||
from requests.models import CONTENT_CHUNK_SIZE
|
||||
|
||||
from ..exceptions import (
|
||||
TokenExpiredError,
|
||||
ClientRequestError,
|
||||
raise_with_traceback
|
||||
)
|
||||
from ..universal_http.async_requests import AsyncBasicRequestsHTTPSender
|
||||
from . import AsyncHTTPSender, AsyncHTTPPolicy, Response, Request
|
||||
from .requests import RequestsContext
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncPipelineRequestsHTTPSender(AsyncHTTPSender):
|
||||
"""Implements a basic Pipeline, that supports universal HTTP lib "requests" driver.
|
||||
"""
|
||||
|
||||
def __init__(self, universal_http_requests_driver: Optional[AsyncBasicRequestsHTTPSender]=None) -> None:
|
||||
self.driver = universal_http_requests_driver or AsyncBasicRequestsHTTPSender()
|
||||
|
||||
async def __aenter__(self) -> 'AsyncPipelineRequestsHTTPSender':
|
||||
await self.driver.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
|
||||
await self.driver.__aexit__(*exc_details)
|
||||
|
||||
async def close(self):
|
||||
await self.__aexit__()
|
||||
|
||||
def build_context(self):
|
||||
# type: () -> RequestsContext
|
||||
return RequestsContext(
|
||||
session=self.driver.session,
|
||||
)
|
||||
|
||||
async def send(self, request: Request, **kwargs) -> Response:
|
||||
"""Send request object according to configuration.
|
||||
|
||||
:param Request request: The request object to be sent.
|
||||
"""
|
||||
if request.context is None: # Should not happen, but make mypy happy and does not hurt
|
||||
request.context = self.build_context()
|
||||
|
||||
if request.context.session is not self.driver.session:
|
||||
kwargs['session'] = request.context.session
|
||||
|
||||
return Response(
|
||||
request,
|
||||
await self.driver.send(request.http_request, **kwargs)
|
||||
)
|
||||
|
||||
|
||||
class AsyncRequestsCredentialsPolicy(AsyncHTTPPolicy):
|
||||
"""Implementation of request-oauthlib except and retry logic.
|
||||
"""
|
||||
def __init__(self, credentials):
|
||||
super(AsyncRequestsCredentialsPolicy, self).__init__()
|
||||
self._creds = credentials
|
||||
|
||||
async def send(self, request, **kwargs):
|
||||
session = request.context.session
|
||||
try:
|
||||
self._creds.signed_session(session)
|
||||
except TypeError: # Credentials does not support session injection
|
||||
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
|
||||
request.context.session = session = self._creds.signed_session()
|
||||
|
||||
try:
|
||||
try:
|
||||
return await self.next.send(request, **kwargs)
|
||||
except (oauth2.rfc6749.errors.InvalidGrantError,
|
||||
oauth2.rfc6749.errors.TokenExpiredError) as err:
|
||||
error = "Token expired or is invalid. Attempting to refresh."
|
||||
_LOGGER.warning(error)
|
||||
|
||||
try:
|
||||
try:
|
||||
self._creds.refresh_session(session)
|
||||
except TypeError: # Credentials does not support session injection
|
||||
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
|
||||
request.context.session = session = self._creds.refresh_session()
|
||||
|
||||
return await self.next.send(request, **kwargs)
|
||||
except (oauth2.rfc6749.errors.InvalidGrantError,
|
||||
oauth2.rfc6749.errors.TokenExpiredError) as err:
|
||||
msg = "Token expired or is invalid."
|
||||
raise_with_traceback(TokenExpiredError, msg, err)
|
||||
|
||||
except (requests.RequestException,
|
||||
oauth2.rfc6749.errors.OAuth2Error) as err:
|
||||
msg = "Error occurred in request."
|
||||
raise_with_traceback(ClientRequestError, msg, err)
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
"""
|
||||
This module is the requests implementation of Pipeline ABC
|
||||
"""
|
||||
from __future__ import absolute_import # we have a "requests" module that conflicts with "requests" on Py2.7
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, List, Callable, Iterator, Any, Union, Dict, Optional # pylint: disable=unused-import
|
||||
import warnings
|
||||
|
||||
from oauthlib import oauth2
|
||||
import requests
|
||||
from requests.models import CONTENT_CHUNK_SIZE
|
||||
|
||||
from urllib3 import Retry # Needs requests 2.16 at least to be safe
|
||||
|
||||
from ..exceptions import (
|
||||
TokenExpiredError,
|
||||
ClientRequestError,
|
||||
raise_with_traceback
|
||||
)
|
||||
from ..universal_http import ClientRequest
|
||||
from ..universal_http.requests import BasicRequestsHTTPSender
|
||||
from . import HTTPSender, HTTPPolicy, Response, Request
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RequestsCredentialsPolicy(HTTPPolicy):
|
||||
"""Implementation of request-oauthlib except and retry logic.
|
||||
"""
|
||||
def __init__(self, credentials):
|
||||
super(RequestsCredentialsPolicy, self).__init__()
|
||||
self._creds = credentials
|
||||
|
||||
def send(self, request, **kwargs):
|
||||
session = request.context.session
|
||||
try:
|
||||
self._creds.signed_session(session)
|
||||
except TypeError: # Credentials does not support session injection
|
||||
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
|
||||
request.context.session = session = self._creds.signed_session()
|
||||
|
||||
try:
|
||||
try:
|
||||
return self.next.send(request, **kwargs)
|
||||
except (oauth2.rfc6749.errors.InvalidGrantError,
|
||||
oauth2.rfc6749.errors.TokenExpiredError) as err:
|
||||
error = "Token expired or is invalid. Attempting to refresh."
|
||||
_LOGGER.warning(error)
|
||||
|
||||
try:
|
||||
try:
|
||||
self._creds.refresh_session(session)
|
||||
except TypeError: # Credentials does not support session injection
|
||||
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
|
||||
request.context.session = session = self._creds.refresh_session()
|
||||
|
||||
return self.next.send(request, **kwargs)
|
||||
except (oauth2.rfc6749.errors.InvalidGrantError,
|
||||
oauth2.rfc6749.errors.TokenExpiredError) as err:
|
||||
msg = "Token expired or is invalid."
|
||||
raise_with_traceback(TokenExpiredError, msg, err)
|
||||
|
||||
except (requests.RequestException,
|
||||
oauth2.rfc6749.errors.OAuth2Error) as err:
|
||||
msg = "Error occurred in request."
|
||||
raise_with_traceback(ClientRequestError, msg, err)
|
||||
|
||||
class RequestsPatchSession(HTTPPolicy):
|
||||
"""Implements request level configuration
|
||||
that are actually to be done at the session level.
|
||||
|
||||
This is highly deprecated, and is totally legacy.
|
||||
The pipeline structure allows way better design for this.
|
||||
"""
|
||||
_protocols = ['http://', 'https://']
|
||||
|
||||
def send(self, request, **kwargs):
|
||||
"""Patch the current session with Request level operation config.
|
||||
|
||||
This is deprecated, we shouldn't patch the session with
|
||||
arguments at the Request, and "config" should be used.
|
||||
"""
|
||||
session = request.context.session
|
||||
|
||||
old_max_redirects = None
|
||||
if 'max_redirects' in kwargs:
|
||||
warnings.warn("max_redirects in operation kwargs is deprecated, use config.redirect_policy instead",
|
||||
DeprecationWarning)
|
||||
old_max_redirects = session.max_redirects
|
||||
session.max_redirects = int(kwargs['max_redirects'])
|
||||
|
||||
old_trust_env = None
|
||||
if 'use_env_proxies' in kwargs:
|
||||
warnings.warn("use_env_proxies in operation kwargs is deprecated, use config.proxies instead",
|
||||
DeprecationWarning)
|
||||
old_trust_env = session.trust_env
|
||||
session.trust_env = bool(kwargs['use_env_proxies'])
|
||||
|
||||
old_retries = {}
|
||||
if 'retries' in kwargs:
|
||||
warnings.warn("retries in operation kwargs is deprecated, use config.retry_policy instead",
|
||||
DeprecationWarning)
|
||||
max_retries = kwargs['retries']
|
||||
for protocol in self._protocols:
|
||||
old_retries[protocol] = session.adapters[protocol].max_retries
|
||||
session.adapters[protocol].max_retries = max_retries
|
||||
|
||||
try:
|
||||
return self.next.send(request, **kwargs)
|
||||
finally:
|
||||
if old_max_redirects:
|
||||
session.max_redirects = old_max_redirects
|
||||
|
||||
if old_trust_env:
|
||||
session.trust_env = old_trust_env
|
||||
|
||||
if old_retries:
|
||||
for protocol in self._protocols:
|
||||
session.adapters[protocol].max_retries = old_retries[protocol]
|
||||
|
||||
class RequestsContext(object):
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
|
||||
|
||||
class PipelineRequestsHTTPSender(HTTPSender):
|
||||
"""Implements a basic Pipeline, that supports universal HTTP lib "requests" driver.
|
||||
"""
|
||||
|
||||
def __init__(self, universal_http_requests_driver=None):
|
||||
# type: (Optional[BasicRequestsHTTPSender]) -> None
|
||||
self.driver = universal_http_requests_driver or BasicRequestsHTTPSender()
|
||||
|
||||
def __enter__(self):
|
||||
# type: () -> PipelineRequestsHTTPSender
|
||||
self.driver.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_details): # pylint: disable=arguments-differ
|
||||
self.driver.__exit__(*exc_details)
|
||||
|
||||
def close(self):
|
||||
self.__exit__()
|
||||
|
||||
def build_context(self):
|
||||
# type: () -> RequestsContext
|
||||
return RequestsContext(
|
||||
session=self.driver.session,
|
||||
)
|
||||
|
||||
def send(self, request, **kwargs):
|
||||
# type: (Request[ClientRequest], Any) -> Response
|
||||
"""Send request object according to configuration.
|
||||
|
||||
:param Request request: The request object to be sent.
|
||||
"""
|
||||
if request.context is None: # Should not happen, but make mypy happy and does not hurt
|
||||
request.context = self.build_context()
|
||||
|
||||
if request.context.session is not self.driver.session:
|
||||
kwargs['session'] = request.context.session
|
||||
|
||||
return Response(
|
||||
request,
|
||||
self.driver.send(request.http_request, **kwargs)
|
||||
)
|
|
@ -0,0 +1,239 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
"""
|
||||
This module represents universal policy that works whatever the HTTPSender implementation
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
import platform
|
||||
|
||||
from typing import Mapping, Any, Optional, AnyStr, Union, IO, cast, TYPE_CHECKING # pylint: disable=unused-import
|
||||
|
||||
from ..version import msrest_version as _msrest_version
|
||||
from . import SansIOHTTPPolicy
|
||||
from ..exceptions import DeserializationError, raise_with_traceback
|
||||
from ..http_logger import log_request, log_response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Request, Response # pylint: disable=unused-import
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HeadersPolicy(SansIOHTTPPolicy):
|
||||
"""A simple policy that sends the given headers
|
||||
with the request.
|
||||
|
||||
This overwrite any headers already defined in the request.
|
||||
"""
|
||||
def __init__(self, headers):
|
||||
# type: (Mapping[str, str]) -> None
|
||||
self.headers = headers
|
||||
|
||||
def on_request(self, request, **kwargs):
|
||||
# type: (Request, Any) -> None
|
||||
http_request = request.http_request
|
||||
http_request.headers.update(self.headers)
|
||||
|
||||
class UserAgentPolicy(SansIOHTTPPolicy):
|
||||
_USERAGENT = "User-Agent"
|
||||
_ENV_ADDITIONAL_USER_AGENT = 'AZURE_HTTP_USER_AGENT'
|
||||
|
||||
def __init__(self, user_agent=None, overwrite=False):
|
||||
# type: (Optional[str], bool) -> None
|
||||
self._overwrite = overwrite
|
||||
if user_agent is None:
|
||||
self._user_agent = "python/{} ({}) msrest/{}".format(
|
||||
platform.python_version(),
|
||||
platform.platform(),
|
||||
_msrest_version
|
||||
)
|
||||
else:
|
||||
self._user_agent = user_agent
|
||||
|
||||
# Whatever you gave me a header explicitly or not,
|
||||
# if the env variable is set, add to it.
|
||||
add_user_agent_header = os.environ.get(self._ENV_ADDITIONAL_USER_AGENT, None)
|
||||
if add_user_agent_header is not None:
|
||||
self.add_user_agent(add_user_agent_header)
|
||||
|
||||
@property
|
||||
def user_agent(self):
|
||||
# type: () -> str
|
||||
"""The current user agent value."""
|
||||
return self._user_agent
|
||||
|
||||
def add_user_agent(self, value):
|
||||
# type: (str) -> None
|
||||
"""Add value to current user agent with a space.
|
||||
|
||||
:param str value: value to add to user agent.
|
||||
"""
|
||||
self._user_agent = "{} {}".format(self._user_agent, value)
|
||||
|
||||
def on_request(self, request, **kwargs):
|
||||
# type: (Request, Any) -> None
|
||||
http_request = request.http_request
|
||||
if self._overwrite or self._USERAGENT not in http_request.headers:
|
||||
http_request.headers[self._USERAGENT] = self._user_agent
|
||||
|
||||
class HTTPLogger(SansIOHTTPPolicy):
|
||||
"""A policy that logs HTTP request and response to the DEBUG logger.
|
||||
|
||||
This accepts both global configuration, and kwargs request level with "enable_http_logger"
|
||||
"""
|
||||
def __init__(self, enable_http_logger = False):
|
||||
self.enable_http_logger = enable_http_logger
|
||||
|
||||
def on_request(self, request, **kwargs):
|
||||
# type: (Request, Any) -> None
|
||||
http_request = request.http_request
|
||||
if kwargs.get("enable_http_logger", self.enable_http_logger):
|
||||
log_request(None, http_request)
|
||||
|
||||
def on_response(self, request, response, **kwargs):
|
||||
# type: (Request, Response, Any) -> None
|
||||
http_request = request.http_request
|
||||
if kwargs.get("enable_http_logger", self.enable_http_logger):
|
||||
log_response(None, http_request, response.http_response, result=response)
|
||||
|
||||
|
||||
class RawDeserializer(SansIOHTTPPolicy):
|
||||
|
||||
JSON_MIMETYPES = [
|
||||
'application/json',
|
||||
'text/json' # Because we're open minded people...
|
||||
]
|
||||
# Name used in context
|
||||
CONTEXT_NAME = "deserialized_data"
|
||||
|
||||
@classmethod
|
||||
def deserialize_from_text(cls, data, content_type=None):
|
||||
# type: (Optional[Union[AnyStr, IO]], Optional[str]) -> Any
|
||||
"""Decode data according to content-type.
|
||||
|
||||
Accept a stream of data as well, but will be load at once in memory for now.
|
||||
|
||||
If no content-type, will return the string version (not bytes, not stream)
|
||||
|
||||
:param data: Input, could be bytes or stream (will be decoded with UTF8) or text
|
||||
:type data: str or bytes or IO
|
||||
:param str content_type: The content type.
|
||||
"""
|
||||
if hasattr(data, 'read'):
|
||||
# Assume a stream
|
||||
data = cast(IO, data).read()
|
||||
|
||||
if isinstance(data, bytes):
|
||||
data_as_str = data.decode(encoding='utf-8-sig')
|
||||
else:
|
||||
# Explain to mypy the correct type.
|
||||
data_as_str = cast(str, data)
|
||||
|
||||
if content_type is None:
|
||||
return data
|
||||
|
||||
if content_type in cls.JSON_MIMETYPES:
|
||||
try:
|
||||
return json.loads(data_as_str)
|
||||
except ValueError as err:
|
||||
raise DeserializationError("JSON is invalid: {}".format(err), err)
|
||||
elif "xml" in (content_type or []):
|
||||
try:
|
||||
return ET.fromstring(data_as_str)
|
||||
except ET.ParseError:
|
||||
# It might be because the server has an issue, and returned JSON with
|
||||
# content-type XML....
|
||||
# So let's try a JSON load, and if it's still broken
|
||||
# let's flow the initial exception
|
||||
def _json_attemp(data):
|
||||
try:
|
||||
return True, json.loads(data)
|
||||
except ValueError:
|
||||
return False, None # Don't care about this one
|
||||
success, json_result = _json_attemp(data)
|
||||
if success:
|
||||
return json_result
|
||||
# If i'm here, it's not JSON, it's not XML, let's scream
|
||||
# and raise the last context in this block (the XML exception)
|
||||
# The function hack is because Py2.7 messes up with exception
|
||||
# context otherwise.
|
||||
_LOGGER.critical("Wasn't XML not JSON, failing")
|
||||
raise_with_traceback(DeserializationError, "XML is invalid")
|
||||
raise DeserializationError("Cannot deserialize content-type: {}".format(content_type))
|
||||
|
||||
@classmethod
|
||||
def deserialize_from_http_generics(cls, body_bytes, headers):
|
||||
# type: (Optional[Union[AnyStr, IO]], Mapping) -> Any
|
||||
"""Deserialize from HTTP response.
|
||||
|
||||
Use bytes and headers to NOT use any requests/aiohttp or whatever
|
||||
specific implementation.
|
||||
Headers will tested for "content-type"
|
||||
"""
|
||||
# Try to use content-type from headers if available
|
||||
content_type = None
|
||||
if 'content-type' in headers:
|
||||
content_type = headers['content-type'].split(";")[0].strip().lower()
|
||||
# Ouch, this server did not declare what it sent...
|
||||
# Let's guess it's JSON...
|
||||
# Also, since Autorest was considering that an empty body was a valid JSON,
|
||||
# need that test as well....
|
||||
else:
|
||||
content_type = "application/json"
|
||||
|
||||
if body_bytes:
|
||||
return cls.deserialize_from_text(body_bytes, content_type)
|
||||
return None
|
||||
|
||||
def on_response(self, request, response, **kwargs):
|
||||
# type: (Request, Response, Any) -> None
|
||||
"""Extract data from the body of a REST response object.
|
||||
|
||||
This will load the entire payload in memory.
|
||||
|
||||
Will follow Content-Type to parse.
|
||||
We assume everything is UTF8 (BOM acceptable).
|
||||
|
||||
:param raw_data: Data to be processed.
|
||||
:param content_type: How to parse if raw_data is a string/bytes.
|
||||
:raises JSONDecodeError: If JSON is requested and parsing is impossible.
|
||||
:raises UnicodeDecodeError: If bytes is not UTF8
|
||||
:raises xml.etree.ElementTree.ParseError: If bytes is not valid XML
|
||||
"""
|
||||
# If response was asked as stream, do NOT read anything and quit now
|
||||
if kwargs.get("stream", True):
|
||||
return
|
||||
|
||||
http_response = response.http_response
|
||||
|
||||
response.context[self.CONTEXT_NAME] = self.deserialize_from_http_generics(
|
||||
http_response.text(),
|
||||
http_response.headers
|
||||
)
|
|
@ -23,6 +23,12 @@
|
|||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
from .poller import LROPoller, NoPolling, PollingMethod
|
||||
import sys
|
||||
|
||||
__all__ = ['LROPoller', 'NoPolling', 'PollingMethod']
|
||||
from .poller import LROPoller, NoPolling, PollingMethod
|
||||
__all__ = ['LROPoller', 'NoPolling', 'PollingMethod']
|
||||
|
||||
if sys.version_info >= (3, 5, 2):
|
||||
# Not executed on old Python, no syntax error
|
||||
from .async_poller import AsyncNoPolling, AsyncPollingMethod, async_poller
|
||||
__all__ += ['AsyncNoPolling', 'AsyncPollingMethod', 'async_poller']
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
from .poller import NoPolling as _NoPolling
|
||||
|
||||
from ..serialization import Model
|
||||
from ..service_client import ServiceClient
|
||||
from ..pipeline import ClientRawResponse
|
||||
|
||||
class AsyncPollingMethod(object):
|
||||
"""ABC class for polling method.
|
||||
"""
|
||||
def initialize(self, client, initial_response, deserialization_callback):
|
||||
raise NotImplementedError("This method needs to be implemented")
|
||||
|
||||
async def run(self):
|
||||
raise NotImplementedError("This method needs to be implemented")
|
||||
|
||||
def status(self):
|
||||
raise NotImplementedError("This method needs to be implemented")
|
||||
|
||||
def finished(self):
|
||||
raise NotImplementedError("This method needs to be implemented")
|
||||
|
||||
def resource(self):
|
||||
raise NotImplementedError("This method needs to be implemented")
|
||||
|
||||
|
||||
class AsyncNoPolling(_NoPolling):
|
||||
"""An empty async poller that returns the deserialized initial response.
|
||||
"""
|
||||
async def run(self):
|
||||
"""Empty run, no polling.
|
||||
|
||||
Just override initial run to add "async"
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
async def async_poller(client, initial_response, deserialization_callback, polling_method):
|
||||
"""Async Poller for long running operations.
|
||||
|
||||
:param client: A msrest service client. Can be a SDK client and it will be casted to a ServiceClient.
|
||||
:type client: msrest.service_client.ServiceClient
|
||||
:param initial_response: The initial call response
|
||||
:type initial_response: msrest.universal_http.ClientResponse or msrest.pipeline.ClientRawResponse
|
||||
:param deserialization_callback: A callback that takes a Response and return a deserialized object. If a subclass of Model is given, this passes "deserialize" as callback.
|
||||
:type deserialization_callback: callable or msrest.serialization.Model
|
||||
:param polling_method: The polling strategy to adopt
|
||||
:type polling_method: msrest.polling.PollingMethod
|
||||
"""
|
||||
|
||||
try:
|
||||
client = client if isinstance(client, ServiceClient) else client._client
|
||||
except AttributeError:
|
||||
raise ValueError("Poller client parameter must be a low-level msrest Service Client or a SDK client.")
|
||||
response = initial_response.response if isinstance(initial_response, ClientRawResponse) else initial_response
|
||||
|
||||
if isinstance(deserialization_callback, type) and issubclass(deserialization_callback, Model):
|
||||
deserialization_callback = deserialization_callback.deserialize
|
||||
|
||||
# Might raise a CloudError
|
||||
polling_method.initialize(client, response, deserialization_callback)
|
||||
|
||||
await polling_method.run()
|
||||
return polling_method.resource()
|
|
@ -43,6 +43,8 @@ import isodate
|
|||
|
||||
from typing import Dict, Any
|
||||
|
||||
from .pipeline.universal import RawDeserializer
|
||||
|
||||
from .exceptions import (
|
||||
ValidationError,
|
||||
SerializationError,
|
||||
|
@ -1335,77 +1337,47 @@ class Deserializer(object):
|
|||
pass # Target is not a Model, no classify
|
||||
return target, target.__class__.__name__
|
||||
|
||||
JSON_MIMETYPES = [
|
||||
'application/json',
|
||||
'text/json' # Because we're open minded people...
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _unpack_content(raw_data, content_type=None):
|
||||
"""Extract data from the body of a REST response object.
|
||||
"""Extract the correct structure for deserialization.
|
||||
|
||||
If raw_data is a requests.Response object, follow Content-Type
|
||||
to parse (ignore content_type parameter).
|
||||
If bytes is given, decode using UTF8 first.
|
||||
If content_type is given, try to parse.
|
||||
Otherwise, return initial data.
|
||||
We assume everything is UTF8 (BOM acceptable).
|
||||
If raw_data is a PipelineResponse, try to extract the result of RawDeserializer.
|
||||
if we can't, raise. Your Pipeline should have a RawDeserializer.
|
||||
|
||||
If not a pipeline response and raw_data is bytes or string, use content-type
|
||||
to decode it. If no content-type, try JSON.
|
||||
|
||||
If raw_data is something else, bypass all logic and return it directly.
|
||||
|
||||
:param raw_data: Data to be processed.
|
||||
:param content_type: How to parse if raw_data is a string/bytes.
|
||||
:raises JSONDecodeError: If JSON is requested and parsing is impossible.
|
||||
:raises UnicodeDecodeError: If bytes is not UTF8
|
||||
"""
|
||||
# Assume this is enough to detect a Pipeline Response without importing it
|
||||
context = getattr(raw_data, "context", {})
|
||||
if context:
|
||||
if RawDeserializer.CONTEXT_NAME in context:
|
||||
return context[RawDeserializer.CONTEXT_NAME]
|
||||
raise ValueError("This pipeline didn't have the RawDeserializer policy; can't deserialize")
|
||||
|
||||
if hasattr(raw_data, 'text'): # Our requests.Response test
|
||||
# Try to use content-type from headers if available
|
||||
if 'content-type' in raw_data.headers:
|
||||
content_type = raw_data.headers['content-type'].split(";")[0].strip().lower()
|
||||
# Ouch, this server did not declare what it sent...
|
||||
# Use Swagger "produces", which will be passed to "content_type" here
|
||||
# If "content_type" also is empty, this means that it's an old version
|
||||
# of Autorest for Python, let's guess it's JSON...
|
||||
# Also, since Autorest was considering that an empty body was a valid JSON,
|
||||
# need that test as well....
|
||||
elif not content_type:
|
||||
if not raw_data.text:
|
||||
return None
|
||||
content_type = "application/json"
|
||||
# Whatever content type, data is readable (not bytes). Get it as a string.
|
||||
data = raw_data.text
|
||||
elif raw_data and isinstance(raw_data, bytes):
|
||||
data = raw_data.decode(encoding='utf-8-sig')
|
||||
else:
|
||||
data = raw_data
|
||||
#Assume this is enough to recognize universal_http.ClientResponse without importing it
|
||||
if hasattr(raw_data, "body"):
|
||||
return RawDeserializer.deserialize_from_http_generics(
|
||||
raw_data.text(),
|
||||
raw_data.headers
|
||||
)
|
||||
|
||||
if content_type in Deserializer.JSON_MIMETYPES:
|
||||
try:
|
||||
return json.loads(data)
|
||||
except ValueError as err:
|
||||
raise DeserializationError("JSON is invalid: {}".format(err), err)
|
||||
elif "xml" in (content_type or []):
|
||||
try:
|
||||
return ET.fromstring(data)
|
||||
except ET.ParseError:
|
||||
# It might be because the server has an issue, and returned JSON with
|
||||
# content-type XML....
|
||||
# So let's try a JSON load, and if it's still broken
|
||||
# let's flow the initial exception
|
||||
def _json_attemp(data):
|
||||
try:
|
||||
return True, json.loads(data)
|
||||
except ValueError:
|
||||
return False, None # Don't care about this one
|
||||
success, data = _json_attemp(data)
|
||||
if success:
|
||||
return data
|
||||
# If i'm here, it's not JSON, it's not XML, let's scream
|
||||
# and raise the last context in this block (the XML exception)
|
||||
# The function hack is because Py2.7 messes up with exception
|
||||
# context otherwise.
|
||||
_LOGGER.critical("Wasn't XML not JSON, failing")
|
||||
raise
|
||||
return data
|
||||
# Assume this enough to recognize requests.Response without importing it.
|
||||
if hasattr(raw_data, '_content_consumed'):
|
||||
return RawDeserializer.deserialize_from_http_generics(
|
||||
raw_data.text,
|
||||
raw_data.headers
|
||||
)
|
||||
|
||||
if isinstance(raw_data, (basestring, bytes)) or hasattr(raw_data, 'read'):
|
||||
return RawDeserializer.deserialize_from_text(raw_data, content_type)
|
||||
return raw_data
|
||||
|
||||
def _instantiate_model(self, response, attrs, additional_properties=None):
|
||||
"""Instantiate a response model passing in deserialized args.
|
||||
|
|
|
@ -24,35 +24,54 @@
|
|||
#
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
try:
|
||||
from urlparse import urljoin, urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urljoin, urlparse
|
||||
import warnings
|
||||
|
||||
from typing import Any, Dict, Union, IO, Tuple, Optional, cast, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration import Configuration
|
||||
|
||||
from oauthlib import oauth2
|
||||
import requests.adapters
|
||||
from typing import List, Any, Dict, Union, IO, Tuple, Optional, Callable, Iterator, cast, TYPE_CHECKING # pylint: disable=unused-import
|
||||
|
||||
from .authentication import Authentication
|
||||
from .pipeline import ClientRequest
|
||||
from .http_logger import log_request, log_response
|
||||
from .exceptions import (
|
||||
TokenExpiredError,
|
||||
ClientRequestError,
|
||||
raise_with_traceback)
|
||||
from .universal_http import ClientRequest, ClientResponse
|
||||
from .universal_http.requests import (
|
||||
RequestsHTTPSender,
|
||||
)
|
||||
from .pipeline import Request, Pipeline, HTTPPolicy, SansIOHTTPPolicy
|
||||
from .pipeline.requests import (
|
||||
PipelineRequestsHTTPSender,
|
||||
RequestsCredentialsPolicy,
|
||||
RequestsPatchSession
|
||||
)
|
||||
from .pipeline.universal import (
|
||||
HTTPLogger,
|
||||
RawDeserializer
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration import Configuration # pylint: disable=unused-import
|
||||
from .universal_http.requests import RequestsClientResponse # pylint: disable=unused-import
|
||||
import requests # pylint: disable=unused-import
|
||||
|
||||
if sys.version_info >= (3, 5, 2):
|
||||
# Not executed on old Python, no syntax error
|
||||
from .async_client import AsyncServiceClientMixin, AsyncSDKClientMixin # type: ignore
|
||||
else:
|
||||
class AsyncSDKClientMixin(object): # type: ignore
|
||||
pass
|
||||
|
||||
class AsyncServiceClientMixin(object): # type: ignore
|
||||
def __init__(self, creds, config):
|
||||
pass
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class SDKClient(object):
|
||||
class SDKClient(AsyncSDKClientMixin):
|
||||
"""The base class of all generated SDK client.
|
||||
"""
|
||||
def __init__(self, creds, config):
|
||||
|
@ -73,121 +92,8 @@ class SDKClient(object):
|
|||
def __exit__(self, *exc_details):
|
||||
self._client.__exit__(*exc_details)
|
||||
|
||||
class _RequestsHTTPDriver(object):
|
||||
|
||||
_protocols = ['http://', 'https://']
|
||||
|
||||
def __init__(self, config):
|
||||
# type: (Configuration) -> None
|
||||
self.config = config
|
||||
self.session = requests.Session()
|
||||
|
||||
def __enter__(self):
|
||||
# type: () -> _RequestsHTTPDriver
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_details):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
self.session.close()
|
||||
|
||||
def configure_session(self, **config):
|
||||
# type: (str) -> Dict[str, Any]
|
||||
"""Apply configuration to session.
|
||||
|
||||
:param config: Specific configuration overrides.
|
||||
:rtype: dict
|
||||
:return: A dict that will be kwarg-send to session.request
|
||||
"""
|
||||
kwargs = self.config.connection() # type: Dict[str, Any]
|
||||
for opt in ['timeout', 'verify', 'cert']:
|
||||
kwargs[opt] = config.get(opt, kwargs[opt])
|
||||
kwargs.update({k:config[k] for k in ['cookies'] if k in config})
|
||||
kwargs['allow_redirects'] = config.get(
|
||||
'allow_redirects', bool(self.config.redirect_policy))
|
||||
|
||||
kwargs['headers'] = self.config.headers.copy()
|
||||
kwargs['headers']['User-Agent'] = self.config.user_agent
|
||||
proxies = config.get('proxies', self.config.proxies())
|
||||
if proxies:
|
||||
kwargs['proxies'] = proxies
|
||||
|
||||
kwargs['stream'] = config.get('stream', True)
|
||||
|
||||
self.session.max_redirects = int(config.get('max_redirects', self.config.redirect_policy()))
|
||||
self.session.trust_env = bool(config.get('use_env_proxies', self.config.proxies.use_env_settings))
|
||||
|
||||
# Patch the redirect method directly *if not done already*
|
||||
if not getattr(self.session.resolve_redirects, 'is_mrest_patched', False):
|
||||
redirect_logic = self.session.resolve_redirects
|
||||
|
||||
def wrapped_redirect(resp, req, **kwargs):
|
||||
attempt = self.config.redirect_policy.check_redirect(resp, req)
|
||||
return redirect_logic(resp, req, **kwargs) if attempt else []
|
||||
wrapped_redirect.is_mrest_patched = True # type: ignore
|
||||
|
||||
self.session.resolve_redirects = wrapped_redirect # type: ignore
|
||||
|
||||
# if "enable_http_logger" is defined at the operation level, take the value.
|
||||
# if not, take the one in the client config
|
||||
# if not, disable http_logger
|
||||
hooks = []
|
||||
if config.get("enable_http_logger", self.config.enable_http_logger):
|
||||
def log_hook(r, *args, **kwargs):
|
||||
log_request(None, r.request)
|
||||
log_response(None, r.request, r, result=r)
|
||||
hooks.append(log_hook)
|
||||
|
||||
def make_user_hook_cb(user_hook, session):
|
||||
def user_hook_cb(r, *args, **kwargs):
|
||||
kwargs.setdefault("msrest", {})['session'] = session
|
||||
return user_hook(r, *args, **kwargs)
|
||||
return user_hook_cb
|
||||
|
||||
for user_hook in self.config.hooks:
|
||||
hooks.append(make_user_hook_cb(user_hook, self.session))
|
||||
|
||||
if hooks:
|
||||
kwargs['hooks'] = {'response': hooks}
|
||||
|
||||
# Change max_retries in current all installed adapters
|
||||
max_retries = config.get('retries', self.config.retry_policy())
|
||||
for protocol in self._protocols:
|
||||
self.session.adapters[protocol].max_retries=max_retries
|
||||
|
||||
output_kwargs = self.config.session_configuration_callback(
|
||||
self.session,
|
||||
self.config,
|
||||
config,
|
||||
**kwargs
|
||||
)
|
||||
if output_kwargs is not None:
|
||||
kwargs = output_kwargs
|
||||
|
||||
return kwargs
|
||||
|
||||
def send(self, request, **config):
|
||||
# type: (ClientRequest, Any) -> requests.Response
|
||||
"""Send request object according to configuration.
|
||||
|
||||
:param ClientRequest request: The request object to be sent.
|
||||
:param config: Any specific config overrides
|
||||
"""
|
||||
kwargs = config.copy()
|
||||
if request.files:
|
||||
kwargs['files'] = request.files
|
||||
elif request.data:
|
||||
kwargs['data'] = request.data
|
||||
kwargs.setdefault("headers", {}).update(request.headers)
|
||||
|
||||
response = self.session.request(
|
||||
request.method,
|
||||
request.url,
|
||||
**kwargs)
|
||||
return response
|
||||
|
||||
class ServiceClient(object):
|
||||
class ServiceClient(AsyncServiceClientMixin):
|
||||
"""REST Service Client.
|
||||
Maintains client pipeline and handles all requests and responses.
|
||||
|
||||
|
@ -197,47 +103,56 @@ class ServiceClient(object):
|
|||
|
||||
def __init__(self, creds, config):
|
||||
# type: (Any, Configuration) -> None
|
||||
if config is None:
|
||||
raise ValueError("Config is a required parameter")
|
||||
self.config = config
|
||||
self.creds = creds if creds else Authentication()
|
||||
self._http_driver = _RequestsHTTPDriver(config)
|
||||
self._creds = creds
|
||||
# Call the mixin AFTER self.config and self._creds
|
||||
super(ServiceClient, self).__init__(creds, config)
|
||||
|
||||
# "pipeline" be should accessible from "config"
|
||||
# In legacy mode this is weird, this config is a parameter of "pipeline"
|
||||
# Should be revamp one day.
|
||||
self.config.pipeline = self._create_default_pipeline()
|
||||
|
||||
def _create_default_pipeline(self):
|
||||
# type: () -> Pipeline[ClientRequest, RequestsClientResponse]
|
||||
|
||||
policies = [
|
||||
self.config.user_agent_policy, # UserAgent policy
|
||||
RequestsPatchSession(), # Support deprecated operation config at the session level
|
||||
self.config.http_logger_policy # HTTP request/response log
|
||||
] # type: List[Union[HTTPPolicy, SansIOHTTPPolicy]]
|
||||
if self._creds:
|
||||
if isinstance(self._creds, (HTTPPolicy, SansIOHTTPPolicy)):
|
||||
policies.insert(1, self._creds)
|
||||
else:
|
||||
# Assume this is the old credentials class, and then requests. Wrap it.
|
||||
policies.insert(1, RequestsCredentialsPolicy(self._creds)) # Set credentials for requests based session
|
||||
|
||||
return Pipeline(
|
||||
policies,
|
||||
PipelineRequestsHTTPSender(RequestsHTTPSender(self.config)) # Send HTTP request using requests
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
# type: () -> ServiceClient
|
||||
self.config.keep_alive = True
|
||||
self._http_driver.__enter__()
|
||||
self.config.pipeline.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_details):
|
||||
self._http_driver.__exit__(*exc_details)
|
||||
self.config.pipeline.__exit__(*exc_details)
|
||||
self.config.keep_alive = False
|
||||
|
||||
def close(self):
|
||||
# type: () -> None
|
||||
"""Close the session if keep_alive is True.
|
||||
"""Close the pipeline if keep_alive is True.
|
||||
"""
|
||||
self._http_driver.close()
|
||||
self.config.pipeline.__exit__()
|
||||
|
||||
def _format_data(self, data):
|
||||
# type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]
|
||||
"""Format field data according to whether it is a stream or
|
||||
a string for a form-data request.
|
||||
|
||||
:param data: The request field data.
|
||||
:type data: str or file-like object.
|
||||
"""
|
||||
if hasattr(data, 'read'):
|
||||
data = cast(IO, data)
|
||||
data_name = None
|
||||
try:
|
||||
if data.name[0] != '<' and data.name[-1] != '>':
|
||||
data_name = os.path.basename(data.name)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
return (data_name, data, "application/octet-stream")
|
||||
return (None, cast(str, data))
|
||||
|
||||
def _request(self, url, params, headers, content, form_content):
|
||||
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
def _request(self, method, url, params, headers, content, form_content):
|
||||
# type: (str, str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
"""Create ClientRequest object.
|
||||
|
||||
:param str url: URL for the request.
|
||||
|
@ -245,10 +160,7 @@ class ServiceClient(object):
|
|||
:param dict headers: Headers
|
||||
:param dict form_content: Form content
|
||||
"""
|
||||
request = ClientRequest()
|
||||
|
||||
if url:
|
||||
request.url = self.format_url(url)
|
||||
request = ClientRequest(method, self.format_url(url))
|
||||
|
||||
if params:
|
||||
request.format_parameters(params)
|
||||
|
@ -266,31 +178,10 @@ class ServiceClient(object):
|
|||
request.add_content(content)
|
||||
|
||||
if form_content:
|
||||
self._add_formdata(request, form_content)
|
||||
request.add_formdata(form_content)
|
||||
|
||||
return request
|
||||
|
||||
def _add_formdata(self, request, content=None):
|
||||
# type: (ClientRequest, Optional[Dict[str, str]]) -> None
|
||||
"""Add data as a multipart form-data request to the request.
|
||||
|
||||
We only deal with file-like objects or strings at this point.
|
||||
The requests is not yet streamed.
|
||||
|
||||
:param ClientRequest request: The request object to be sent.
|
||||
:param dict headers: Any headers to add to the request.
|
||||
:param dict content: Dictionary of the fields of the formdata.
|
||||
"""
|
||||
if content is None:
|
||||
content = {}
|
||||
content_type = request.headers.pop('Content-Type', None) if request.headers else None
|
||||
|
||||
if content_type and content_type.lower() == 'application/x-www-form-urlencoded':
|
||||
# Do NOT use "add_content" that assumes input is JSON
|
||||
request.data = {f: d for f, d in content.items() if d is not None}
|
||||
else: # Assume "multipart/form-data"
|
||||
request.files = {f: self._format_data(d) for f, d in content.items() if d is not None}
|
||||
|
||||
def send_formdata(self, request, headers=None, content=None, **config):
|
||||
"""Send data as a multipart form-data request.
|
||||
We only deal with file-like objects or strings at this point.
|
||||
|
@ -304,10 +195,10 @@ class ServiceClient(object):
|
|||
:param config: Any specific config overrides.
|
||||
"""
|
||||
request.headers = headers
|
||||
self._add_formdata(request, content)
|
||||
request.add_formdata(content)
|
||||
return self.send(request, **config)
|
||||
|
||||
def send(self, request, headers=None, content=None, **config):
|
||||
def send(self, request, headers=None, content=None, **kwargs):
|
||||
"""Prepare and send request object according to configuration.
|
||||
|
||||
:param ClientRequest request: The request object to be sent.
|
||||
|
@ -315,92 +206,54 @@ class ServiceClient(object):
|
|||
:param content: Any body data to add to the request.
|
||||
:param config: Any specific config overrides
|
||||
"""
|
||||
if self.config.keep_alive:
|
||||
http_driver = self._http_driver
|
||||
else:
|
||||
http_driver = _RequestsHTTPDriver(self.config)
|
||||
|
||||
try:
|
||||
self.creds.signed_session(http_driver.session)
|
||||
except TypeError: # Credentials does not support session injection
|
||||
http_driver.session = self.creds.signed_session()
|
||||
if http_driver is self._http_driver:
|
||||
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
|
||||
|
||||
kwargs = http_driver.configure_session(**config)
|
||||
|
||||
# "content" and "headers" are deprecated, only old SDK
|
||||
if headers:
|
||||
request.headers.update(headers)
|
||||
if not request.files and request.data == [] and content is not None:
|
||||
if not request.files and request.data is None and content is not None:
|
||||
request.add_content(content)
|
||||
# End of deprecation
|
||||
|
||||
response = None
|
||||
kwargs.setdefault('stream', True)
|
||||
try:
|
||||
try:
|
||||
response = http_driver.send(request, **kwargs)
|
||||
return response
|
||||
except (oauth2.rfc6749.errors.InvalidGrantError,
|
||||
oauth2.rfc6749.errors.TokenExpiredError) as err:
|
||||
error = "Token expired or is invalid. Attempting to refresh."
|
||||
_LOGGER.warning(error)
|
||||
|
||||
try:
|
||||
try:
|
||||
self.creds.refresh_session(http_driver.session)
|
||||
except TypeError: # Credentials does not support session injection
|
||||
http_driver.session = self.creds.refresh_session()
|
||||
if http_driver is self._http_driver:
|
||||
_LOGGER.warning("Your credentials class does not support session injection. Performance will not be at the maximum.")
|
||||
# Only reconfigure on refresh if it's a new session
|
||||
kwargs = http_driver.configure_session(**config)
|
||||
|
||||
response = http_driver.send(request, **kwargs)
|
||||
return response
|
||||
except (oauth2.rfc6749.errors.InvalidGrantError,
|
||||
oauth2.rfc6749.errors.TokenExpiredError) as err:
|
||||
msg = "Token expired or is invalid."
|
||||
raise_with_traceback(TokenExpiredError, msg, err)
|
||||
|
||||
except (requests.RequestException,
|
||||
oauth2.rfc6749.errors.OAuth2Error) as err:
|
||||
msg = "Error occurred in request."
|
||||
raise_with_traceback(ClientRequestError, msg, err)
|
||||
pipeline_response = self.config.pipeline.run(request, **kwargs)
|
||||
# There is too much thing that expects this method to return a "requests.Response"
|
||||
# to break it in a compatible release.
|
||||
# Also, to be pragmatic in the "sync" world "requests" rules anyway.
|
||||
# However, attach the Universal HTTP response
|
||||
# to get the streaming generator.
|
||||
response = pipeline_response.http_response.internal_response
|
||||
response._universal_http_response = pipeline_response.http_response
|
||||
response.context = pipeline_response.context
|
||||
return response
|
||||
finally:
|
||||
self._close_local_session_if_necessary(response, http_driver, kwargs['stream'])
|
||||
self._close_local_session_if_necessary(response, kwargs['stream'])
|
||||
|
||||
def _close_local_session_if_necessary(self, response, http_driver, stream):
|
||||
# Do NOT close session if using my own HTTP driver. No exception.
|
||||
if self._http_driver is http_driver:
|
||||
return
|
||||
def _close_local_session_if_necessary(self, response, stream):
|
||||
# Here, it's a local session, I might close it.
|
||||
if not response or not stream:
|
||||
http_driver.session.close()
|
||||
if not self.config.keep_alive and (not response or not stream):
|
||||
self.config.pipeline._sender.driver.session.close()
|
||||
|
||||
def stream_download(self, data, callback):
|
||||
# type: (Union[requests.Response, ClientResponse], Callable) -> Iterator[bytes]
|
||||
"""Generator for streaming request body data.
|
||||
|
||||
:param data: A response object to be streamed.
|
||||
:param callback: Custom callback for monitoring progress.
|
||||
"""
|
||||
block = self.config.connection.data_block_size
|
||||
if not data._content_consumed:
|
||||
with contextlib.closing(data) as response:
|
||||
for chunk in response.iter_content(block):
|
||||
if not chunk:
|
||||
break
|
||||
if callback and callable(callback):
|
||||
callback(chunk, response=response)
|
||||
yield chunk
|
||||
else:
|
||||
for chunk in data.iter_content(block):
|
||||
if not chunk:
|
||||
break
|
||||
if callback and callable(callback):
|
||||
callback(chunk, response=data)
|
||||
yield chunk
|
||||
data.close()
|
||||
try:
|
||||
# Assume this is ClientResponse, which it should be if backward compat was not important
|
||||
return cast(ClientResponse, data).stream_download(block, callback)
|
||||
except AttributeError:
|
||||
try:
|
||||
# Assume this is the patched requests.Response from "send"
|
||||
return data._universal_http_response.stream_download(block, callback) # type: ignore
|
||||
except AttributeError:
|
||||
# Assume this is a raw requests.Response
|
||||
from .universal_http.requests import RequestsClientResponse
|
||||
response = RequestsClientResponse(None, data)
|
||||
return response.stream_download(block, callback)
|
||||
|
||||
def stream_upload(self, data, callback):
|
||||
"""Generator for streaming request body data.
|
||||
|
@ -446,8 +299,8 @@ class ServiceClient(object):
|
|||
DeprecationWarning)
|
||||
self.config.headers[header] = value
|
||||
|
||||
def get(self, url=None, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
def get(self, url, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
"""Create a GET request object.
|
||||
|
||||
:param str url: The request URL.
|
||||
|
@ -455,12 +308,12 @@ class ServiceClient(object):
|
|||
:param dict headers: Headers
|
||||
:param dict form_content: Form content
|
||||
"""
|
||||
request = self._request(url, params, headers, content, form_content)
|
||||
request = self._request('GET', url, params, headers, content, form_content)
|
||||
request.method = 'GET'
|
||||
return request
|
||||
|
||||
def put(self, url=None, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
def put(self, url, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
"""Create a PUT request object.
|
||||
|
||||
:param str url: The request URL.
|
||||
|
@ -468,12 +321,11 @@ class ServiceClient(object):
|
|||
:param dict headers: Headers
|
||||
:param dict form_content: Form content
|
||||
"""
|
||||
request = self._request(url, params, headers, content, form_content)
|
||||
request.method = 'PUT'
|
||||
request = self._request('PUT', url, params, headers, content, form_content)
|
||||
return request
|
||||
|
||||
def post(self, url=None, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
def post(self, url, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
"""Create a POST request object.
|
||||
|
||||
:param str url: The request URL.
|
||||
|
@ -481,12 +333,11 @@ class ServiceClient(object):
|
|||
:param dict headers: Headers
|
||||
:param dict form_content: Form content
|
||||
"""
|
||||
request = self._request(url, params, headers, content, form_content)
|
||||
request.method = 'POST'
|
||||
request = self._request('POST', url, params, headers, content, form_content)
|
||||
return request
|
||||
|
||||
def head(self, url=None, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
def head(self, url, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
"""Create a HEAD request object.
|
||||
|
||||
:param str url: The request URL.
|
||||
|
@ -494,12 +345,11 @@ class ServiceClient(object):
|
|||
:param dict headers: Headers
|
||||
:param dict form_content: Form content
|
||||
"""
|
||||
request = self._request(url, params, headers, content, form_content)
|
||||
request.method = 'HEAD'
|
||||
request = self._request('HEAD', url, params, headers, content, form_content)
|
||||
return request
|
||||
|
||||
def patch(self, url=None, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
def patch(self, url, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
"""Create a PATCH request object.
|
||||
|
||||
:param str url: The request URL.
|
||||
|
@ -507,12 +357,11 @@ class ServiceClient(object):
|
|||
:param dict headers: Headers
|
||||
:param dict form_content: Form content
|
||||
"""
|
||||
request = self._request(url, params, headers, content, form_content)
|
||||
request.method = 'PATCH'
|
||||
request = self._request('PATCH', url, params, headers, content, form_content)
|
||||
return request
|
||||
|
||||
def delete(self, url=None, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
def delete(self, url, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
"""Create a DELETE request object.
|
||||
|
||||
:param str url: The request URL.
|
||||
|
@ -520,12 +369,11 @@ class ServiceClient(object):
|
|||
:param dict headers: Headers
|
||||
:param dict form_content: Form content
|
||||
"""
|
||||
request = self._request(url, params, headers, content, form_content)
|
||||
request.method = 'DELETE'
|
||||
request = self._request('DELETE', url, params, headers, content, form_content)
|
||||
return request
|
||||
|
||||
def merge(self, url=None, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (Optional[str], Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
def merge(self, url, params=None, headers=None, content=None, form_content=None):
|
||||
# type: (str, Optional[Dict[str, str]], Optional[Dict[str, str]], Any, Optional[Dict[str, Any]]) -> ClientRequest
|
||||
"""Create a MERGE request object.
|
||||
|
||||
:param str url: The request URL.
|
||||
|
@ -533,6 +381,5 @@ class ServiceClient(object):
|
|||
:param dict headers: Headers
|
||||
:param dict form_content: Form content
|
||||
"""
|
||||
request = self._request(url, params, headers, content, form_content)
|
||||
request.method = 'MERGE'
|
||||
request = self._request('MERGE', url, params, headers, content, form_content)
|
||||
return request
|
||||
|
|
|
@ -0,0 +1,458 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import absolute_import # we have a "requests" module that conflicts with "requests" on Py2.7
|
||||
import abc
|
||||
try:
|
||||
import configparser
|
||||
from configparser import NoOptionError
|
||||
except ImportError:
|
||||
import ConfigParser as configparser # type: ignore
|
||||
from ConfigParser import NoOptionError # type: ignore
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar, cast, IO, List, Union, Any, Mapping, Dict, Optional, Tuple, Callable, Iterator # pylint: disable=unused-import
|
||||
|
||||
HTTPResponseType = TypeVar("HTTPResponseType", bound='HTTPClientResponse')
|
||||
|
||||
# This file is NOT using any "requests" HTTP implementation
|
||||
# However, the CaseInsensitiveDict is handy.
|
||||
# If one day we reach the point where "requests" can be skip totally,
|
||||
# might provide our own implementation
|
||||
from requests.structures import CaseInsensitiveDict
|
||||
|
||||
from ..exceptions import ClientRequestError, raise_with_traceback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..serialization import Model # pylint: disable=unused-import
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
ABC = abc.ABC
|
||||
except AttributeError: # Python 2.7, abc exists, but not ABC
|
||||
ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) # type: ignore
|
||||
|
||||
try:
|
||||
from contextlib import AbstractContextManager # type: ignore
|
||||
except ImportError: # Python <= 3.5
|
||||
class AbstractContextManager(object): # type: ignore
|
||||
def __enter__(self):
|
||||
"""Return `self` upon entering the runtime context."""
|
||||
return self
|
||||
|
||||
@abc.abstractmethod
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
|
||||
class HTTPSender(AbstractContextManager, ABC):
|
||||
"""An http sender ABC.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def send(self, request, **config):
|
||||
# type: (ClientRequest, Any) -> ClientResponse
|
||||
"""Send the request using this HTTP sender.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HTTPSenderConfiguration(object):
|
||||
"""HTTP sender configuration.
|
||||
|
||||
This is composed of generic HTTP configuration, and could be use as a common
|
||||
HTTP configuration format.
|
||||
|
||||
:param str filepath: Path to existing config file (optional).
|
||||
"""
|
||||
|
||||
def __init__(self, filepath=None):
|
||||
# Communication configuration
|
||||
self.connection = ClientConnection()
|
||||
|
||||
# Headers (sent with every requests)
|
||||
self.headers = {} # type: Dict[str, str]
|
||||
|
||||
# ProxyConfiguration
|
||||
self.proxies = ClientProxies()
|
||||
|
||||
# Redirect configuration
|
||||
self.redirect_policy = ClientRedirectPolicy()
|
||||
|
||||
self._config = configparser.ConfigParser()
|
||||
self._config.optionxform = str # type: ignore
|
||||
|
||||
if filepath:
|
||||
self.load(filepath)
|
||||
|
||||
def _clear_config(self):
|
||||
# type: () -> None
|
||||
"""Clearout config object in memory."""
|
||||
for section in self._config.sections():
|
||||
self._config.remove_section(section)
|
||||
|
||||
def save(self, filepath):
|
||||
# type: (str) -> None
|
||||
"""Save current configuration to file.
|
||||
|
||||
:param str filepath: Path to file where settings will be saved.
|
||||
:raises: ValueError if supplied filepath cannot be written to.
|
||||
"""
|
||||
sections = [
|
||||
"Connection",
|
||||
"Proxies",
|
||||
"RedirectPolicy"]
|
||||
for section in sections:
|
||||
self._config.add_section(section)
|
||||
|
||||
self._config.set("Connection", "timeout", self.connection.timeout)
|
||||
self._config.set("Connection", "verify", self.connection.verify)
|
||||
self._config.set("Connection", "cert", self.connection.cert)
|
||||
|
||||
self._config.set("Proxies", "proxies", self.proxies.proxies)
|
||||
self._config.set("Proxies", "env_settings",
|
||||
self.proxies.use_env_settings)
|
||||
|
||||
self._config.set("RedirectPolicy", "allow", self.redirect_policy.allow)
|
||||
self._config.set("RedirectPolicy", "max_redirects",
|
||||
self.redirect_policy.max_redirects)
|
||||
|
||||
try:
|
||||
with open(filepath, 'w') as configfile:
|
||||
self._config.write(configfile)
|
||||
except (KeyError, EnvironmentError):
|
||||
error = "Supplied config filepath invalid."
|
||||
raise_with_traceback(ValueError, error)
|
||||
finally:
|
||||
self._clear_config()
|
||||
|
||||
def load(self, filepath):
|
||||
# type: (str) -> None
|
||||
"""Load configuration from existing file.
|
||||
|
||||
:param str filepath: Path to existing config file.
|
||||
:raises: ValueError if supplied config file is invalid.
|
||||
"""
|
||||
try:
|
||||
self._config.read(filepath)
|
||||
import ast
|
||||
self.connection.timeout = \
|
||||
self._config.getint("Connection", "timeout")
|
||||
self.connection.verify = \
|
||||
self._config.getboolean("Connection", "verify")
|
||||
self.connection.cert = \
|
||||
self._config.get("Connection", "cert")
|
||||
|
||||
self.proxies.proxies = \
|
||||
ast.literal_eval(self._config.get("Proxies", "proxies"))
|
||||
self.proxies.use_env_settings = \
|
||||
self._config.getboolean("Proxies", "env_settings")
|
||||
|
||||
self.redirect_policy.allow = \
|
||||
self._config.getboolean("RedirectPolicy", "allow")
|
||||
self.redirect_policy.max_redirects = \
|
||||
self._config.getint("RedirectPolicy", "max_redirects")
|
||||
|
||||
except (ValueError, EnvironmentError, NoOptionError):
|
||||
error = "Supplied config file incompatible."
|
||||
raise_with_traceback(ValueError, error)
|
||||
finally:
|
||||
self._clear_config()
|
||||
|
||||
|
||||
class ClientRequest(object):
|
||||
"""Represents a HTTP request.
|
||||
|
||||
URL can be given without query parameters, to be added later using "format_parameters".
|
||||
|
||||
Instance can be created without data, to be added later using "add_content"
|
||||
|
||||
Instance can be created without files, to be added later using "add_formdata"
|
||||
|
||||
:param str method: HTTP method (GET, HEAD, etc.)
|
||||
:param str url: At least complete scheme/host/path
|
||||
:param dict[str,str] headers: HTTP headers
|
||||
:param files: Files list.
|
||||
:param data: Body to be sent.
|
||||
:type data: bytes or str.
|
||||
"""
|
||||
def __init__(self, method, url, headers=None, files=None, data=None):
|
||||
# type: (str, str, Mapping[str, str], Any, Any) -> None
|
||||
self.method = method
|
||||
self.url = url
|
||||
self.headers = CaseInsensitiveDict(headers)
|
||||
self.files = files
|
||||
self.data = data
|
||||
|
||||
def __repr__(self):
|
||||
return '<ClientRequest [%s]>' % (self.method)
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
"""Alias to data."""
|
||||
return self.data
|
||||
|
||||
@body.setter
|
||||
def body(self, value):
|
||||
self.data = value
|
||||
|
||||
def format_parameters(self, params):
|
||||
# type: (Dict[str, str]) -> None
|
||||
"""Format parameters into a valid query string.
|
||||
It's assumed all parameters have already been quoted as
|
||||
valid URL strings.
|
||||
|
||||
:param dict params: A dictionary of parameters.
|
||||
"""
|
||||
query = urlparse(self.url).query
|
||||
if query:
|
||||
self.url = self.url.partition('?')[0]
|
||||
existing_params = {
|
||||
p[0]: p[-1]
|
||||
for p in [p.partition('=') for p in query.split('&')]
|
||||
}
|
||||
params.update(existing_params)
|
||||
query_params = ["{}={}".format(k, v) for k, v in params.items()]
|
||||
query = '?' + '&'.join(query_params)
|
||||
self.url = self.url + query
|
||||
|
||||
def add_content(self, data):
|
||||
# type: (Optional[Union[Dict[str, Any], ET.Element]]) -> None
|
||||
"""Add a body to the request.
|
||||
|
||||
:param data: Request body data, can be a json serializable
|
||||
object (e.g. dictionary) or a generator (e.g. file data).
|
||||
"""
|
||||
if data is None:
|
||||
return
|
||||
|
||||
if isinstance(data, ET.Element):
|
||||
bytes_data = ET.tostring(data, encoding="utf8")
|
||||
self.headers['Content-Length'] = str(len(bytes_data))
|
||||
self.data = bytes_data
|
||||
return
|
||||
|
||||
# By default, assume JSON
|
||||
try:
|
||||
self.data = json.dumps(data)
|
||||
self.headers['Content-Length'] = str(len(self.data))
|
||||
except TypeError:
|
||||
self.data = data
|
||||
|
||||
@staticmethod
|
||||
def _format_data(data):
|
||||
# type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]]
|
||||
"""Format field data according to whether it is a stream or
|
||||
a string for a form-data request.
|
||||
|
||||
:param data: The request field data.
|
||||
:type data: str or file-like object.
|
||||
"""
|
||||
if hasattr(data, 'read'):
|
||||
data = cast(IO, data)
|
||||
data_name = None
|
||||
try:
|
||||
if data.name[0] != '<' and data.name[-1] != '>':
|
||||
data_name = os.path.basename(data.name)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
return (data_name, data, "application/octet-stream")
|
||||
return (None, cast(str, data))
|
||||
|
||||
def add_formdata(self, content=None):
|
||||
# type: (Optional[Dict[str, str]]) -> None
|
||||
"""Add data as a multipart form-data request to the request.
|
||||
|
||||
We only deal with file-like objects or strings at this point.
|
||||
The requests is not yet streamed.
|
||||
|
||||
:param dict headers: Any headers to add to the request.
|
||||
:param dict content: Dictionary of the fields of the formdata.
|
||||
"""
|
||||
if content is None:
|
||||
content = {}
|
||||
content_type = self.headers.pop('Content-Type', None) if self.headers else None
|
||||
|
||||
if content_type and content_type.lower() == 'application/x-www-form-urlencoded':
|
||||
# Do NOT use "add_content" that assumes input is JSON
|
||||
self.data = {f: d for f, d in content.items() if d is not None}
|
||||
else: # Assume "multipart/form-data"
|
||||
self.files = {f: self._format_data(d) for f, d in content.items() if d is not None}
|
||||
|
||||
class HTTPClientResponse(object):
|
||||
"""Represent a HTTP response.
|
||||
|
||||
No body is defined here on purpose, since async pipeline
|
||||
will provide async ways to access the body
|
||||
|
||||
You have two differents types of body:
|
||||
- Full in-memory using "body" as bytes
|
||||
"""
|
||||
def __init__(self, request, internal_response):
|
||||
# type: (ClientRequest, Any) -> None
|
||||
self.request = request
|
||||
self.internal_response = internal_response
|
||||
self.status_code = None # type: Optional[int]
|
||||
self.headers = {} # type: Dict[str, str]
|
||||
self.reason = None # type: Optional[str]
|
||||
|
||||
def body(self):
|
||||
# type: () -> bytes
|
||||
"""Return the whole body as bytes in memory.
|
||||
"""
|
||||
pass
|
||||
|
||||
def text(self, encoding=None):
|
||||
# type: (str) -> str
|
||||
"""Return the whole body as a string.
|
||||
|
||||
:param str encoding: The encoding to apply. If None, use "utf-8".
|
||||
Implementation can be smarter if they want (using headers).
|
||||
"""
|
||||
return self.body().decode(encoding or "utf-8")
|
||||
|
||||
def raise_for_status(self):
|
||||
"""Raise for status. Should be overriden, but basic implementation provided.
|
||||
"""
|
||||
if self.status_code >= 400:
|
||||
raise ClientRequestError("Received status code {}".format(self.status_code))
|
||||
|
||||
|
||||
class ClientResponse(HTTPClientResponse):
|
||||
|
||||
def stream_download(self, chunk_size=None, callback=None):
|
||||
# type: (Optional[int], Optional[Callable]) -> Iterator[bytes]
|
||||
"""Generator for streaming request body data.
|
||||
|
||||
Should be implemented by sub-classes if streaming download
|
||||
is supported.
|
||||
|
||||
:param callback: Custom callback for monitoring progress.
|
||||
:param int chunk_size:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ClientRedirectPolicy(object):
|
||||
"""Redirect configuration settings.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.allow = True
|
||||
self.max_redirects = 30
|
||||
|
||||
def __bool__(self):
|
||||
# type: () -> bool
|
||||
"""Whether redirects are allowed."""
|
||||
return self.allow
|
||||
|
||||
def __call__(self):
|
||||
# type: () -> int
|
||||
"""Return configuration to be applied to connection."""
|
||||
debug = "Configuring redirects: allow=%r, max=%r"
|
||||
_LOGGER.debug(debug, self.allow, self.max_redirects)
|
||||
return self.max_redirects
|
||||
|
||||
|
||||
class ClientProxies(object):
|
||||
"""Proxy configuration settings.
|
||||
Proxies can also be configured using HTTP_PROXY and HTTPS_PROXY
|
||||
environment variables, in which case set use_env_settings to True.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.proxies = {}
|
||||
self.use_env_settings = True
|
||||
|
||||
def __call__(self):
|
||||
# type: () -> Dict[str, str]
|
||||
"""Return configuration to be applied to connection."""
|
||||
proxy_string = "\n".join(
|
||||
[" {}: {}".format(k, v) for k, v in self.proxies.items()])
|
||||
|
||||
_LOGGER.debug("Configuring proxies: %r", proxy_string)
|
||||
debug = "Evaluate proxies against ENV settings: %r"
|
||||
_LOGGER.debug(debug, self.use_env_settings)
|
||||
return self.proxies
|
||||
|
||||
def add(self, protocol, proxy_url):
|
||||
# type: (str, str) -> None
|
||||
"""Add proxy.
|
||||
|
||||
:param str protocol: Protocol for which proxy is to be applied. Can
|
||||
be 'http', 'https', etc. Can also include host.
|
||||
:param str proxy_url: The proxy URL. Where basic auth is required,
|
||||
use the format: http://user:password@host
|
||||
"""
|
||||
self.proxies[protocol] = proxy_url
|
||||
|
||||
|
||||
class ClientConnection(object):
|
||||
"""Request connection configuration settings.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.timeout = 100
|
||||
self.verify = True
|
||||
self.cert = None
|
||||
self.data_block_size = 4096
|
||||
|
||||
def __call__(self):
|
||||
# type: () -> Dict[str, Union[str, int]]
|
||||
"""Return configuration to be applied to connection."""
|
||||
debug = "Configuring request: timeout=%r, verify=%r, cert=%r"
|
||||
_LOGGER.debug(debug, self.timeout, self.verify, self.cert)
|
||||
return {'timeout': self.timeout,
|
||||
'verify': self.verify,
|
||||
'cert': self.cert}
|
||||
|
||||
|
||||
__all__ = [
|
||||
'ClientRequest',
|
||||
'ClientResponse',
|
||||
'HTTPSender',
|
||||
# Generic HTTP configuration
|
||||
'HTTPSenderConfiguration',
|
||||
'ClientRedirectPolicy',
|
||||
'ClientProxies',
|
||||
'ClientConnection'
|
||||
]
|
||||
|
||||
try:
|
||||
from .async_abc import AsyncHTTPSender, AsyncClientResponse # pylint: disable=unused-import
|
||||
from .async_abc import __all__ as _async_all
|
||||
__all__ += _async_all
|
||||
except SyntaxError: # Python 2
|
||||
pass
|
|
@ -0,0 +1,100 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
from typing import Any, Callable, AsyncIterator, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import AsyncHTTPSender, ClientRequest, AsyncClientResponse
|
||||
|
||||
# Matching requests, because why not?
|
||||
CONTENT_CHUNK_SIZE = 10 * 1024
|
||||
|
||||
|
||||
class AioHTTPSender(AsyncHTTPSender):
|
||||
"""AioHttp HTTP sender implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, *, loop=None):
|
||||
self._session = aiohttp.ClientSession(loop=loop)
|
||||
|
||||
async def __aenter__(self):
|
||||
await self._session.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
|
||||
await self._session.__aexit__(*exc_details)
|
||||
|
||||
async def send(self, request: ClientRequest, **config: Any) -> AsyncClientResponse:
|
||||
"""Send the request using this HTTP sender.
|
||||
|
||||
Will pre-load the body into memory to be available with a sync method.
|
||||
pass stream=True to avoid this behavior.
|
||||
"""
|
||||
result = await self._session.request(
|
||||
request.method,
|
||||
request.url,
|
||||
**config
|
||||
)
|
||||
response = AioHttpClientResponse(request, result)
|
||||
if not config.get("stream", False):
|
||||
await response.load_body()
|
||||
return response
|
||||
|
||||
|
||||
class AioHttpClientResponse(AsyncClientResponse):
|
||||
def __init__(self, request: ClientRequest, aiohttp_response: aiohttp.ClientResponse) -> None:
|
||||
super(AioHttpClientResponse, self).__init__(request, aiohttp_response)
|
||||
# https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
|
||||
self.status_code = aiohttp_response.status
|
||||
self.headers = aiohttp_response.headers
|
||||
self.reason = aiohttp_response.reason
|
||||
self._body = None
|
||||
|
||||
def body(self) -> bytes:
|
||||
"""Return the whole body as bytes in memory.
|
||||
"""
|
||||
if not self._body:
|
||||
raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.")
|
||||
return self._body
|
||||
|
||||
async def load_body(self) -> None:
|
||||
"""Load in memory the body, so it could be accessible from sync methods."""
|
||||
self._body = await self.internal_response.read()
|
||||
|
||||
def raise_for_status(self):
|
||||
self.internal_response.raise_for_status()
|
||||
|
||||
def stream_download(self, chunk_size: Optional[int] = None, callback: Optional[Callable] = None) -> AsyncIterator[bytes]:
|
||||
"""Generator for streaming request body data.
|
||||
"""
|
||||
chunk_size = chunk_size or CONTENT_CHUNK_SIZE
|
||||
async def async_gen(resp):
|
||||
while True:
|
||||
chunk = await resp.content.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
callback(chunk, resp)
|
||||
return async_gen(self.internal_response)
|
|
@ -0,0 +1,90 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
import abc
|
||||
|
||||
from typing import Any, List, Union, Callable, AsyncIterator, Optional
|
||||
|
||||
try:
|
||||
from contextlib import AbstractAsyncContextManager # type: ignore
|
||||
except ImportError: # Python <= 3.7
|
||||
class AbstractAsyncContextManager(object): # type: ignore
|
||||
async def __aenter__(self):
|
||||
"""Return `self` upon entering the runtime context."""
|
||||
return self
|
||||
|
||||
@abc.abstractmethod
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
from . import ClientRequest, HTTPClientResponse
|
||||
|
||||
|
||||
class AsyncClientResponse(HTTPClientResponse):
|
||||
|
||||
def stream_download(self, chunk_size: Optional[int] = None, callback: Optional[Callable] = None) -> AsyncIterator[bytes]:
|
||||
"""Generator for streaming request body data.
|
||||
|
||||
Should be implemented by sub-classes if streaming download
|
||||
is supported.
|
||||
|
||||
:param callback: Custom callback for monitoring progress.
|
||||
:param int chunk_size:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AsyncHTTPSender(AbstractAsyncContextManager, abc.ABC):
|
||||
"""An http sender ABC.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(self, request: ClientRequest, **config: Any) -> AsyncClientResponse:
|
||||
"""Send the request using this HTTP sender.
|
||||
"""
|
||||
pass
|
||||
|
||||
def build_context(self) -> Any:
|
||||
"""Allow the sender to build a context that will be passed
|
||||
across the pipeline with the request.
|
||||
|
||||
Return type has no constraints. Implementation is not
|
||||
required and None by default.
|
||||
"""
|
||||
return None
|
||||
|
||||
def __enter__(self):
|
||||
raise TypeError("Use 'async with' instead")
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# __exit__ should exist in pair with __enter__ but never executed
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
__all__ = [
|
||||
'AsyncHTTPSender',
|
||||
'AsyncClientResponse'
|
||||
]
|
|
@ -0,0 +1,240 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, AsyncIterator as AsyncIteratorType
|
||||
|
||||
from oauthlib import oauth2
|
||||
import requests
|
||||
from requests.models import CONTENT_CHUNK_SIZE
|
||||
|
||||
from ..exceptions import (
|
||||
TokenExpiredError,
|
||||
ClientRequestError,
|
||||
raise_with_traceback)
|
||||
from . import AsyncHTTPSender, ClientRequest, AsyncClientResponse
|
||||
from .requests import (
|
||||
BasicRequestsHTTPSender,
|
||||
RequestsHTTPSender,
|
||||
HTTPRequestsClientResponse
|
||||
)
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncBasicRequestsHTTPSender(BasicRequestsHTTPSender, AsyncHTTPSender): # type: ignore
|
||||
|
||||
async def __aenter__(self):
|
||||
return super(AsyncBasicRequestsHTTPSender, self).__enter__()
|
||||
|
||||
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
|
||||
return super(AsyncBasicRequestsHTTPSender, self).__exit__()
|
||||
|
||||
async def send(self, request: ClientRequest, **kwargs: Any) -> AsyncClientResponse: # type: ignore
|
||||
"""Send the request using this HTTP sender.
|
||||
"""
|
||||
# It's not recommended to provide its own session, and is mostly
|
||||
# to enable some legacy code to plug correctly
|
||||
session = kwargs.pop('session', self.session)
|
||||
|
||||
loop = kwargs.get("loop", asyncio.get_event_loop())
|
||||
future = loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
session.request,
|
||||
request.method,
|
||||
request.url,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
try:
|
||||
return AsyncRequestsClientResponse(
|
||||
request,
|
||||
await future
|
||||
)
|
||||
except requests.RequestException as err:
|
||||
msg = "Error occurred in request."
|
||||
raise_with_traceback(ClientRequestError, msg, err)
|
||||
|
||||
class AsyncRequestsHTTPSender(AsyncBasicRequestsHTTPSender, RequestsHTTPSender): # type: ignore
|
||||
|
||||
async def send(self, request: ClientRequest, **kwargs: Any) -> AsyncClientResponse: # type: ignore
|
||||
"""Send the request using this HTTP sender.
|
||||
"""
|
||||
requests_kwargs = self._configure_send(request, **kwargs)
|
||||
return await super(AsyncRequestsHTTPSender, self).send(request, **requests_kwargs)
|
||||
|
||||
|
||||
class _MsrestStopIteration(Exception):
|
||||
pass
|
||||
|
||||
def _msrest_next(iterator):
|
||||
""""To avoid:
|
||||
TypeError: StopIteration interacts badly with generators and cannot be raised into a Future
|
||||
"""
|
||||
try:
|
||||
return next(iterator)
|
||||
except StopIteration:
|
||||
raise _MsrestStopIteration()
|
||||
|
||||
class StreamDownloadGenerator(AsyncIterator):
|
||||
|
||||
def __init__(self, response: requests.Response, user_callback: Optional[Callable] = None, block: Optional[int] = None) -> None:
|
||||
self.response = response
|
||||
self.block = block or CONTENT_CHUNK_SIZE
|
||||
self.user_callback = user_callback
|
||||
self.iter_content_func = self.response.iter_content(self.block)
|
||||
|
||||
async def __anext__(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
chunk = await loop.run_in_executor(
|
||||
None,
|
||||
_msrest_next,
|
||||
self.iter_content_func,
|
||||
)
|
||||
if not chunk:
|
||||
raise _MsrestStopIteration()
|
||||
if self.user_callback and callable(self.user_callback):
|
||||
self.user_callback(chunk, self.response)
|
||||
return chunk
|
||||
except _MsrestStopIteration:
|
||||
self.response.close()
|
||||
raise StopAsyncIteration()
|
||||
except Exception as err:
|
||||
_LOGGER.warning("Unable to stream download: %s", err)
|
||||
self.response.close()
|
||||
raise
|
||||
|
||||
class AsyncRequestsClientResponse(AsyncClientResponse, HTTPRequestsClientResponse):
|
||||
|
||||
def stream_download(self, chunk_size: Optional[int] = None, callback: Optional[Callable] = None) -> AsyncIteratorType[bytes]:
|
||||
"""Generator for streaming request body data.
|
||||
|
||||
:param callback: Custom callback for monitoring progress.
|
||||
:param int chunk_size:
|
||||
"""
|
||||
return StreamDownloadGenerator(
|
||||
self.internal_response,
|
||||
callback,
|
||||
chunk_size
|
||||
)
|
||||
|
||||
|
||||
# Trio support
|
||||
try:
|
||||
import trio
|
||||
|
||||
class TrioStreamDownloadGenerator(AsyncIterator):
|
||||
|
||||
def __init__(self, response: requests.Response, user_callback: Optional[Callable] = None, block: Optional[int] = None) -> None:
|
||||
self.response = response
|
||||
self.block = block or CONTENT_CHUNK_SIZE
|
||||
self.user_callback = user_callback
|
||||
self.iter_content_func = self.response.iter_content(self.block)
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await trio.run_sync_in_worker_thread(
|
||||
_msrest_next,
|
||||
self.iter_content_func,
|
||||
)
|
||||
if not chunk:
|
||||
raise _MsrestStopIteration()
|
||||
if self.user_callback and callable(self.user_callback):
|
||||
self.user_callback(chunk, self.response)
|
||||
return chunk
|
||||
except _MsrestStopIteration:
|
||||
self.response.close()
|
||||
raise StopAsyncIteration()
|
||||
except Exception as err:
|
||||
_LOGGER.warning("Unable to stream download: %s", err)
|
||||
self.response.close()
|
||||
raise
|
||||
|
||||
class TrioAsyncRequestsClientResponse(AsyncClientResponse, HTTPRequestsClientResponse):
|
||||
|
||||
def stream_download(self, chunk_size: Optional[int] = None, callback: Optional[Callable] = None) -> AsyncIteratorType[bytes]:
|
||||
"""Generator for streaming request body data.
|
||||
|
||||
:param callback: Custom callback for monitoring progress.
|
||||
:param int chunk_size:
|
||||
"""
|
||||
return TrioStreamDownloadGenerator(
|
||||
self.internal_response,
|
||||
callback,
|
||||
chunk_size
|
||||
)
|
||||
|
||||
|
||||
class AsyncTrioBasicRequestsHTTPSender(BasicRequestsHTTPSender, AsyncHTTPSender): # type: ignore
|
||||
|
||||
async def __aenter__(self):
|
||||
return super(AsyncTrioBasicRequestsHTTPSender, self).__enter__()
|
||||
|
||||
async def __aexit__(self, *exc_details): # pylint: disable=arguments-differ
|
||||
return super(AsyncTrioBasicRequestsHTTPSender, self).__exit__()
|
||||
|
||||
async def send(self, request: ClientRequest, **kwargs: Any) -> AsyncClientResponse: # type: ignore
|
||||
"""Send the request using this HTTP sender.
|
||||
"""
|
||||
# It's not recommended to provide its own session, and is mostly
|
||||
# to enable some legacy code to plug correctly
|
||||
session = kwargs.pop('session', self.session)
|
||||
|
||||
trio_limiter = kwargs.get("trio_limiter", None)
|
||||
future = trio.run_sync_in_worker_thread(
|
||||
functools.partial(
|
||||
session.request,
|
||||
request.method,
|
||||
request.url,
|
||||
**kwargs
|
||||
),
|
||||
limiter=trio_limiter
|
||||
)
|
||||
try:
|
||||
return TrioAsyncRequestsClientResponse(
|
||||
request,
|
||||
await future
|
||||
)
|
||||
except requests.RequestException as err:
|
||||
msg = "Error occurred in request."
|
||||
raise_with_traceback(ClientRequestError, msg, err)
|
||||
|
||||
class AsyncTrioRequestsHTTPSender(AsyncTrioBasicRequestsHTTPSender, RequestsHTTPSender): # type: ignore
|
||||
|
||||
async def send(self, request: ClientRequest, **kwargs: Any) -> AsyncClientResponse: # type: ignore
|
||||
"""Send the request using this HTTP sender.
|
||||
"""
|
||||
requests_kwargs = self._configure_send(request, **kwargs)
|
||||
return await super(AsyncTrioRequestsHTTPSender, self).send(request, **requests_kwargs)
|
||||
|
||||
except ImportError:
|
||||
# trio not installed
|
||||
pass
|
|
@ -0,0 +1,459 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
"""
|
||||
This module is the requests implementation of Pipeline ABC
|
||||
"""
|
||||
from __future__ import absolute_import # we have a "requests" module that conflicts with "requests" on Py2.7
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, List, Callable, Iterator, Any, Union, Dict, Optional # pylint: disable=unused-import
|
||||
import warnings
|
||||
|
||||
from oauthlib import oauth2
|
||||
import requests
|
||||
from requests.models import CONTENT_CHUNK_SIZE
|
||||
|
||||
from urllib3 import Retry # Needs requests 2.16 at least to be safe
|
||||
|
||||
from ..exceptions import (
|
||||
TokenExpiredError,
|
||||
ClientRequestError,
|
||||
raise_with_traceback)
|
||||
from . import HTTPSender, HTTPClientResponse, ClientResponse, HTTPSenderConfiguration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import ClientRequest # pylint: disable=unused-import
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class HTTPRequestsClientResponse(HTTPClientResponse):
|
||||
def __init__(self, request, requests_response):
|
||||
super(HTTPRequestsClientResponse, self).__init__(request, requests_response)
|
||||
self.status_code = requests_response.status_code
|
||||
self.headers = requests_response.headers
|
||||
self.reason = requests_response.reason
|
||||
|
||||
def body(self):
|
||||
return self.internal_response.content
|
||||
|
||||
def text(self, encoding=None):
|
||||
if encoding:
|
||||
self.internal_response.encoding = encoding
|
||||
return self.internal_response.text
|
||||
|
||||
def raise_for_status(self):
|
||||
self.internal_response.raise_for_status()
|
||||
|
||||
class RequestsClientResponse(HTTPRequestsClientResponse, ClientResponse):
|
||||
|
||||
def stream_download(self, chunk_size=None, callback=None):
|
||||
# type: (Optional[int], Optional[Callable]) -> Iterator[bytes]
|
||||
"""Generator for streaming request body data.
|
||||
|
||||
:param callback: Custom callback for monitoring progress.
|
||||
:param int chunk_size:
|
||||
"""
|
||||
chunk_size = chunk_size or CONTENT_CHUNK_SIZE
|
||||
with contextlib.closing(self.internal_response) as response:
|
||||
# https://github.com/PyCQA/pylint/issues/1437
|
||||
for chunk in response.iter_content(chunk_size): # pylint: disable=no-member
|
||||
if not chunk:
|
||||
break
|
||||
if callback and callable(callback):
|
||||
callback(chunk, response=response)
|
||||
yield chunk
|
||||
|
||||
|
||||
class BasicRequestsHTTPSender(HTTPSender):
|
||||
"""Implements a basic requests HTTP sender.
|
||||
|
||||
Since requests team recommends to use one session per requests, you should
|
||||
not consider this class as thread-safe, since it will use one Session
|
||||
per instance.
|
||||
|
||||
In this simple implementation:
|
||||
- You provide the configured session if you want to, or a basic session is created.
|
||||
- All kwargs received by "send" are sent to session.request directly
|
||||
"""
|
||||
|
||||
def __init__(self, session=None):
|
||||
# type: (Optional[requests.Session]) -> None
|
||||
self.session = session or requests.Session()
|
||||
|
||||
def __enter__(self):
|
||||
# type: () -> BasicRequestsHTTPSender
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_details): # pylint: disable=arguments-differ
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
self.session.close()
|
||||
|
||||
def send(self, request, **kwargs):
|
||||
# type: (ClientRequest, Any) -> ClientResponse
|
||||
"""Send request object according to configuration.
|
||||
|
||||
Allowed kwargs are:
|
||||
- session : will override the driver session and use yours. Should NOT be done unless really required.
|
||||
- anything else is sent straight to requests.
|
||||
|
||||
:param ClientRequest request: The request object to be sent.
|
||||
"""
|
||||
# It's not recommended to provide its own session, and is mostly
|
||||
# to enable some legacy code to plug correctly
|
||||
session = kwargs.pop('session', self.session)
|
||||
try:
|
||||
response = session.request(
|
||||
request.method,
|
||||
request.url,
|
||||
**kwargs)
|
||||
except requests.RequestException as err:
|
||||
msg = "Error occurred in request."
|
||||
raise_with_traceback(ClientRequestError, msg, err)
|
||||
|
||||
return RequestsClientResponse(request, response)
|
||||
|
||||
|
||||
def _patch_redirect(session):
|
||||
# type: (requests.Session) -> None
|
||||
"""Whether redirect policy should be applied based on status code.
|
||||
|
||||
HTTP spec says that on 301/302 not HEAD/GET, should NOT redirect.
|
||||
But requests does, to follow browser more than spec
|
||||
https://github.com/requests/requests/blob/f6e13ccfc4b50dc458ee374e5dba347205b9a2da/requests/sessions.py#L305-L314
|
||||
|
||||
This patches "requests" to be more HTTP compliant.
|
||||
|
||||
Note that this is super dangerous, since technically this is not public API.
|
||||
"""
|
||||
def enforce_http_spec(resp, request):
|
||||
if resp.status_code in (301, 302) and \
|
||||
request.method not in ['GET', 'HEAD']:
|
||||
return False
|
||||
return True
|
||||
|
||||
redirect_logic = session.resolve_redirects
|
||||
|
||||
def wrapped_redirect(resp, req, **kwargs):
|
||||
attempt = enforce_http_spec(resp, req)
|
||||
return redirect_logic(resp, req, **kwargs) if attempt else []
|
||||
wrapped_redirect.is_msrest_patched = True # type: ignore
|
||||
|
||||
session.resolve_redirects = wrapped_redirect # type: ignore
|
||||
|
||||
class RequestsHTTPSender(BasicRequestsHTTPSender):
|
||||
"""A requests HTTP sender that can consume a msrest.Configuration object.
|
||||
|
||||
This instance will consume the following configuration attributes:
|
||||
- connection
|
||||
- proxies
|
||||
- retry_policy
|
||||
- redirect_policy
|
||||
- enable_http_logger
|
||||
- hooks
|
||||
- session_configuration_callback
|
||||
"""
|
||||
|
||||
_protocols = ['http://', 'https://']
|
||||
|
||||
# Set of authorized kwargs at the operation level
|
||||
_REQUESTS_KWARGS = [
|
||||
'cookies',
|
||||
'verify',
|
||||
'timeout',
|
||||
'allow_redirects',
|
||||
'proxies',
|
||||
'verify',
|
||||
'cert'
|
||||
]
|
||||
|
||||
def __init__(self, config=None):
|
||||
# type: (Optional[RequestHTTPSenderConfiguration]) -> None
|
||||
self._session_mapping = threading.local()
|
||||
self.config = config or RequestHTTPSenderConfiguration()
|
||||
super(RequestsHTTPSender, self).__init__()
|
||||
|
||||
@property # type: ignore
|
||||
def session(self):
|
||||
try:
|
||||
return self._session_mapping.session
|
||||
except AttributeError:
|
||||
self._session_mapping.session = requests.Session()
|
||||
self._init_session(self._session_mapping.session)
|
||||
return self._session_mapping.session
|
||||
|
||||
@session.setter
|
||||
def session(self, value):
|
||||
self._init_session(value)
|
||||
self._session_mapping.session = value
|
||||
|
||||
def _init_session(self, session):
|
||||
# type: (requests.Session) -> None
|
||||
"""Init session level configuration of requests.
|
||||
|
||||
This is initialization I want to do once only on a session.
|
||||
"""
|
||||
_patch_redirect(session)
|
||||
|
||||
# Change max_retries in current all installed adapters
|
||||
max_retries = self.config.retry_policy()
|
||||
for protocol in self._protocols:
|
||||
session.adapters[protocol].max_retries = max_retries
|
||||
|
||||
def _configure_send(self, request, **kwargs):
|
||||
# type: (ClientRequest, Any) -> Dict[str, str]
|
||||
"""Configure the kwargs to use with requests.
|
||||
|
||||
See "send" for kwargs details.
|
||||
|
||||
:param ClientRequest request: The request object to be sent.
|
||||
:returns: The requests.Session.request kwargs
|
||||
:rtype: dict[str,str]
|
||||
"""
|
||||
requests_kwargs = {} # type: Any
|
||||
session = kwargs.pop('session', self.session)
|
||||
|
||||
# If custom session was not create here
|
||||
if session is not self.session:
|
||||
self._init_session(session)
|
||||
|
||||
session.max_redirects = int(self.config.redirect_policy())
|
||||
session.trust_env = bool(self.config.proxies.use_env_settings)
|
||||
|
||||
# Initialize requests_kwargs with "config" value
|
||||
requests_kwargs.update(self.config.connection())
|
||||
requests_kwargs['allow_redirects'] = bool(self.config.redirect_policy)
|
||||
requests_kwargs['headers'] = self.config.headers.copy()
|
||||
|
||||
proxies = self.config.proxies()
|
||||
if proxies:
|
||||
requests_kwargs['proxies'] = proxies
|
||||
|
||||
# Replace by operation level kwargs
|
||||
# We allow some of them, since some like stream or json are controled by msrest
|
||||
for key in kwargs:
|
||||
if key in self._REQUESTS_KWARGS:
|
||||
requests_kwargs[key] = kwargs[key]
|
||||
|
||||
# Hooks. Deprecated, should be a policy
|
||||
def make_user_hook_cb(user_hook, session):
|
||||
def user_hook_cb(r, *args, **kwargs):
|
||||
kwargs.setdefault("msrest", {})['session'] = session
|
||||
return user_hook(r, *args, **kwargs)
|
||||
return user_hook_cb
|
||||
|
||||
hooks = []
|
||||
for user_hook in self.config.hooks:
|
||||
hooks.append(make_user_hook_cb(user_hook, self.session))
|
||||
|
||||
if hooks:
|
||||
requests_kwargs['hooks'] = {'response': hooks}
|
||||
|
||||
# Configuration callback. Deprecated, should be a policy
|
||||
output_kwargs = self.config.session_configuration_callback(
|
||||
session,
|
||||
self.config,
|
||||
kwargs,
|
||||
**requests_kwargs
|
||||
)
|
||||
if output_kwargs is not None:
|
||||
requests_kwargs = output_kwargs
|
||||
|
||||
# If custom session was not create here
|
||||
if session is not self.session:
|
||||
requests_kwargs['session'] = session
|
||||
|
||||
### Autorest forced kwargs now ###
|
||||
|
||||
# If Autorest needs this response to be streamable. True for compat.
|
||||
requests_kwargs['stream'] = kwargs.get('stream', True)
|
||||
|
||||
if request.files:
|
||||
requests_kwargs['files'] = request.files
|
||||
elif request.data:
|
||||
requests_kwargs['data'] = request.data
|
||||
requests_kwargs['headers'].update(request.headers)
|
||||
|
||||
return requests_kwargs
|
||||
|
||||
def send(self, request, **kwargs):
|
||||
# type: (ClientRequest, Any) -> ClientResponse
|
||||
"""Send request object according to configuration.
|
||||
|
||||
Available kwargs:
|
||||
- session : will override the driver session and use yours. Should NOT be done unless really required.
|
||||
- A subset of what requests.Session.request can receive:
|
||||
|
||||
- cookies
|
||||
- verify
|
||||
- timeout
|
||||
- allow_redirects
|
||||
- proxies
|
||||
- verify
|
||||
- cert
|
||||
|
||||
Everything else will be silently ignored.
|
||||
|
||||
:param ClientRequest request: The request object to be sent.
|
||||
"""
|
||||
requests_kwargs = self._configure_send(request, **kwargs)
|
||||
return super(RequestsHTTPSender, self).send(request, **requests_kwargs)
|
||||
|
||||
|
||||
class ClientRetryPolicy(object):
|
||||
"""Retry configuration settings.
|
||||
Container for retry policy object.
|
||||
"""
|
||||
|
||||
safe_codes = [i for i in range(500) if i != 408] + [501, 505]
|
||||
|
||||
def __init__(self):
|
||||
self.policy = Retry()
|
||||
self.policy.total = 3
|
||||
self.policy.connect = 3
|
||||
self.policy.read = 3
|
||||
self.policy.backoff_factor = 0.8
|
||||
self.policy.BACKOFF_MAX = 90
|
||||
|
||||
retry_codes = [i for i in range(999) if i not in self.safe_codes]
|
||||
self.policy.status_forcelist = retry_codes
|
||||
self.policy.method_whitelist = ['HEAD', 'TRACE', 'GET', 'PUT',
|
||||
'OPTIONS', 'DELETE', 'POST', 'PATCH']
|
||||
|
||||
def __call__(self):
|
||||
# type: () -> Retry
|
||||
"""Return configuration to be applied to connection."""
|
||||
debug = ("Configuring retry: max_retries=%r, "
|
||||
"backoff_factor=%r, max_backoff=%r")
|
||||
_LOGGER.debug(
|
||||
debug, self.retries, self.backoff_factor, self.max_backoff)
|
||||
return self.policy
|
||||
|
||||
@property
|
||||
def retries(self):
|
||||
# type: () -> int
|
||||
"""Total number of allowed retries."""
|
||||
return self.policy.total
|
||||
|
||||
@retries.setter
|
||||
def retries(self, value):
|
||||
# type: (int) -> None
|
||||
self.policy.total = value
|
||||
self.policy.connect = value
|
||||
self.policy.read = value
|
||||
|
||||
@property
|
||||
def backoff_factor(self):
|
||||
# type: () -> Union[int, float]
|
||||
"""Factor by which back-off delay is incementally increased."""
|
||||
return self.policy.backoff_factor
|
||||
|
||||
@backoff_factor.setter
|
||||
def backoff_factor(self, value):
|
||||
# type: (Union[int, float]) -> None
|
||||
self.policy.backoff_factor = value
|
||||
|
||||
@property
|
||||
def max_backoff(self):
|
||||
# type: () -> int
|
||||
"""Max retry back-off delay."""
|
||||
return self.policy.BACKOFF_MAX
|
||||
|
||||
@max_backoff.setter
|
||||
def max_backoff(self, value):
|
||||
# type: (int) -> None
|
||||
self.policy.BACKOFF_MAX = value
|
||||
|
||||
def default_session_configuration_callback(session, global_config, local_config, **kwargs): # pylint: disable=unused-argument
|
||||
# type: (requests.Session, RequestHTTPSenderConfiguration, Dict[str,str], str) -> Dict[str, str]
|
||||
"""Configuration callback if you need to change default session configuration.
|
||||
|
||||
:param requests.Session session: The session.
|
||||
:param Configuration global_config: The global configuration.
|
||||
:param dict[str,str] local_config: The on-the-fly configuration passed on the call.
|
||||
:param dict[str,str] kwargs: The current computed values for session.request method.
|
||||
:return: Must return kwargs, to be passed to session.request. If None is return, initial kwargs will be used.
|
||||
:rtype: dict[str,str]
|
||||
"""
|
||||
return kwargs
|
||||
|
||||
class RequestHTTPSenderConfiguration(HTTPSenderConfiguration):
|
||||
"""Requests specific HTTP sender configuration.
|
||||
|
||||
:param str filepath: Path to existing config file (optional).
|
||||
"""
|
||||
|
||||
def __init__(self, filepath=None):
|
||||
# type: (Optional[str]) -> None
|
||||
|
||||
super(RequestHTTPSenderConfiguration, self).__init__()
|
||||
|
||||
# Retry configuration
|
||||
self.retry_policy = ClientRetryPolicy()
|
||||
|
||||
# Requests hooks. Must respect requests hook callback signature
|
||||
# Note that we will inject the following parameters:
|
||||
# - kwargs['msrest']['session'] with the current session
|
||||
self.hooks = [] # type: List[Callable[[requests.Response, str, str], None]]
|
||||
|
||||
self.session_configuration_callback = default_session_configuration_callback
|
||||
|
||||
if filepath:
|
||||
self.load(filepath)
|
||||
|
||||
def save(self, filepath):
|
||||
"""Save current configuration to file.
|
||||
|
||||
:param str filepath: Path to file where settings will be saved.
|
||||
:raises: ValueError if supplied filepath cannot be written to.
|
||||
"""
|
||||
self._config.add_section("RetryPolicy")
|
||||
self._config.set("RetryPolicy", "retries", str(self.retry_policy.retries))
|
||||
self._config.set("RetryPolicy", "backoff_factor",
|
||||
str(self.retry_policy.backoff_factor))
|
||||
self._config.set("RetryPolicy", "max_backoff",
|
||||
str(self.retry_policy.max_backoff))
|
||||
super(RequestHTTPSenderConfiguration, self).save(filepath)
|
||||
|
||||
def load(self, filepath):
|
||||
try:
|
||||
self.retry_policy.retries = \
|
||||
self._config.getint("RetryPolicy", "retries")
|
||||
self.retry_policy.backoff_factor = \
|
||||
self._config.getfloat("RetryPolicy", "backoff_factor")
|
||||
self.retry_policy.max_backoff = \
|
||||
self._config.getint("RetryPolicy", "max_backoff")
|
||||
except (ValueError, EnvironmentError, NoOptionError):
|
||||
error = "Supplied config file incompatible."
|
||||
raise_with_traceback(ValueError, error)
|
||||
finally:
|
||||
self._clear_config()
|
||||
super(RequestHTTPSenderConfiguration, self).load(filepath)
|
|
@ -25,4 +25,4 @@
|
|||
# --------------------------------------------------------------------------
|
||||
|
||||
#: version of this package. Use msrest.__version__ instead
|
||||
msrest_version = "0.5.5"
|
||||
msrest_version = "0.6.0rc1"
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
[MASTER]
|
||||
ignore-patterns=test_*
|
||||
reports=no
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
# For all codes, run 'pylint --list-msgs' or go to 'https://pylint.readthedocs.io/en/latest/reference_guide/features.html'
|
||||
# locally-disabled: Warning locally suppressed using disable-msg
|
||||
# cyclic-import: because of https://github.com/PyCQA/pylint/issues/850
|
||||
# too-many-arguments: Due to the nature of the CLI many commands have large arguments set which reflect in large arguments set in corresponding methods.
|
||||
disable=missing-docstring,locally-disabled,fixme,cyclic-import,too-many-arguments,invalid-name,duplicate-code
|
||||
|
||||
[FORMAT]
|
||||
max-line-length=120
|
||||
|
||||
[VARIABLES]
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=yes
|
||||
|
||||
[DESIGN]
|
||||
# Maximum number of locals for function / method body
|
||||
max-locals=25
|
||||
# Maximum number of branch for function / method body
|
||||
max-branches=20
|
||||
|
||||
[SIMILARITIES]
|
||||
min-similarity-lines=10
|
||||
|
||||
[BASIC]
|
||||
# Naming hints based on PEP 8 (https://www.python.org/dev/peps/pep-0008/#naming-conventions).
|
||||
# Consider these guidelines and not hard rules. Read PEP 8 for more details.
|
||||
|
||||
# The invalid-name checker must be **enabled** for these hints to be used.
|
||||
include-naming-hint=yes
|
||||
|
||||
module-name-hint=lowercase (keep short; underscores are discouraged)
|
||||
const-name-hint=UPPER_CASE_WITH_UNDERSCORES
|
||||
class-name-hint=CapitalizedWords
|
||||
class-attribute-name-hint=lower_case_with_underscores
|
||||
attr-name-hint=lower_case_with_underscores
|
||||
method-name-hint=lower_case_with_underscores
|
||||
function-name-hint=lower_case_with_underscores
|
||||
argument-name-hint=lower_case_with_underscores
|
||||
variable-name-hint=lower_case_with_underscores
|
||||
inlinevar-name-hint=lower_case_with_underscores (short is OK)
|
|
@ -2,4 +2,7 @@
|
|||
universal=1
|
||||
|
||||
[mypy]
|
||||
ignore_missing_imports = True
|
||||
ignore_missing_imports = True
|
||||
|
||||
[tool:pytest]
|
||||
addopts = --durations=10
|
6
setup.py
6
setup.py
|
@ -28,7 +28,7 @@ from setuptools import setup, find_packages
|
|||
|
||||
setup(
|
||||
name='msrest',
|
||||
version='0.5.5',
|
||||
version='0.6.0rc1',
|
||||
author='Microsoft Corporation',
|
||||
packages=find_packages(exclude=["tests", "tests.*"]),
|
||||
url=("https://github.com/Azure/msrest-for-python"),
|
||||
|
@ -56,5 +56,9 @@ setup(
|
|||
extras_require={
|
||||
":python_version<'3.4'": ['enum34>=1.0.4'],
|
||||
":python_version<'3.5'": ['typing'],
|
||||
"async:python_version>='3.5'": [
|
||||
'aiohttp>=3.0',
|
||||
'aiodns'
|
||||
],
|
||||
}
|
||||
)
|
||||
|
|
|
@ -0,0 +1,188 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
import io
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
import mock
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from oauthlib import oauth2
|
||||
|
||||
from msrest import ServiceClient
|
||||
from msrest.authentication import OAuthTokenAuthentication
|
||||
from msrest.configuration import Configuration
|
||||
|
||||
from msrest import Configuration
|
||||
from msrest.exceptions import ClientRequestError, TokenExpiredError
|
||||
from msrest.universal_http import ClientRequest
|
||||
from msrest.universal_http.async_requests import AsyncRequestsClientResponse
|
||||
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 5, 2), "Async tests only on 3.5.2 minimal")
|
||||
class TestServiceClient(object):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send(self):
|
||||
|
||||
cfg = Configuration("/")
|
||||
cfg.headers = {'Test': 'true'}
|
||||
creds = mock.create_autospec(OAuthTokenAuthentication)
|
||||
|
||||
client = ServiceClient(creds, cfg)
|
||||
|
||||
req_response = requests.Response()
|
||||
req_response._content = br'{"real": true}' # Has to be valid bytes JSON
|
||||
req_response._content_consumed = True
|
||||
req_response.status_code = 200
|
||||
|
||||
def side_effect(*args, **kwargs):
|
||||
return req_response
|
||||
|
||||
session = mock.create_autospec(requests.Session)
|
||||
session.request.side_effect = side_effect
|
||||
session.adapters = {
|
||||
"http://": HTTPAdapter(),
|
||||
"https://": HTTPAdapter(),
|
||||
}
|
||||
# Be sure the mock does not trick me
|
||||
assert not hasattr(session.resolve_redirects, 'is_msrest_patched')
|
||||
|
||||
client.config.async_pipeline._sender.driver.session = session
|
||||
|
||||
client._creds.signed_session.return_value = session
|
||||
client._creds.refresh_session.return_value = session
|
||||
|
||||
request = ClientRequest('GET', '/')
|
||||
await client.async_send(request, stream=False)
|
||||
session.request.call_count = 0
|
||||
session.request.assert_called_with(
|
||||
'GET',
|
||||
'/',
|
||||
allow_redirects=True,
|
||||
cert=None,
|
||||
headers={
|
||||
'User-Agent': cfg.user_agent,
|
||||
'Test': 'true' # From global config
|
||||
},
|
||||
stream=False,
|
||||
timeout=100,
|
||||
verify=True
|
||||
)
|
||||
assert session.resolve_redirects.is_msrest_patched
|
||||
|
||||
request = client.get('/', headers={'id':'1234'}, content={'Test':'Data'})
|
||||
await client.async_send(request, stream=False)
|
||||
session.request.assert_called_with(
|
||||
'GET',
|
||||
'/',
|
||||
data='{"Test": "Data"}',
|
||||
allow_redirects=True,
|
||||
cert=None,
|
||||
headers={
|
||||
'User-Agent': cfg.user_agent,
|
||||
'Content-Length': '16',
|
||||
'id':'1234',
|
||||
'Accept': 'application/json',
|
||||
'Test': 'true' # From global config
|
||||
},
|
||||
stream=False,
|
||||
timeout=100,
|
||||
verify=True
|
||||
)
|
||||
assert session.request.call_count == 1
|
||||
session.request.call_count = 0
|
||||
assert session.resolve_redirects.is_msrest_patched
|
||||
|
||||
request = client.get('/', headers={'id':'1234'}, content={'Test':'Data'})
|
||||
session.request.side_effect = requests.RequestException("test")
|
||||
with pytest.raises(ClientRequestError):
|
||||
await client.async_send(request, test='value', stream=False)
|
||||
session.request.assert_called_with(
|
||||
'GET',
|
||||
'/',
|
||||
data='{"Test": "Data"}',
|
||||
allow_redirects=True,
|
||||
cert=None,
|
||||
headers={
|
||||
'User-Agent': cfg.user_agent,
|
||||
'Content-Length': '16',
|
||||
'id':'1234',
|
||||
'Accept': 'application/json',
|
||||
'Test': 'true' # From global config
|
||||
},
|
||||
stream=False,
|
||||
timeout=100,
|
||||
verify=True
|
||||
)
|
||||
assert session.request.call_count == 1
|
||||
session.request.call_count = 0
|
||||
assert session.resolve_redirects.is_msrest_patched
|
||||
|
||||
session.request.side_effect = oauth2.rfc6749.errors.InvalidGrantError("test")
|
||||
with pytest.raises(TokenExpiredError):
|
||||
await client.async_send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value')
|
||||
assert session.request.call_count == 2
|
||||
session.request.call_count = 0
|
||||
|
||||
session.request.side_effect = ValueError("test")
|
||||
with pytest.raises(ValueError):
|
||||
await client.async_send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_stream_download(self):
|
||||
|
||||
req_response = requests.Response()
|
||||
req_response._content = "abc"
|
||||
req_response._content_consumed = True
|
||||
req_response.status_code = 200
|
||||
|
||||
client_response = AsyncRequestsClientResponse(
|
||||
None,
|
||||
req_response
|
||||
)
|
||||
|
||||
def user_callback(chunk, response):
|
||||
assert response is req_response
|
||||
assert chunk in ["a", "b", "c"]
|
||||
|
||||
async_iterator = client_response.stream_download(1, user_callback)
|
||||
result = ""
|
||||
async for value in async_iterator:
|
||||
result += value
|
||||
assert result == "abc"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -0,0 +1,173 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#
|
||||
#--------------------------------------------------------------------------
|
||||
import sys
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
from msrest.paging import Paged
|
||||
|
||||
class FakePaged(Paged):
|
||||
_attribute_map = {
|
||||
'next_link': {'key': 'nextLink', 'type': 'str'},
|
||||
'current_page': {'key': 'value', 'type': '[str]'}
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FakePaged, self).__init__(*args, **kwargs)
|
||||
|
||||
class TestPaging(object):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_paging(self):
|
||||
|
||||
async def internal_paging(next_link=None, raw=False):
|
||||
if not next_link:
|
||||
return {
|
||||
'nextLink': 'page2',
|
||||
'value': ['value1.0', 'value1.1']
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'nextLink': None,
|
||||
'value': ['value2.0', 'value2.1']
|
||||
}
|
||||
|
||||
deserialized = FakePaged(None, {}, async_command=internal_paging)
|
||||
|
||||
# 3.6 only : result_iterated = [obj async for obj in deserialized]
|
||||
result_iterated = []
|
||||
async for obj in deserialized:
|
||||
result_iterated.append(obj)
|
||||
|
||||
assert ['value1.0', 'value1.1', 'value2.0', 'value2.1'] == result_iterated
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_advance_paging(self):
|
||||
|
||||
async def internal_paging(next_link=None, raw=False):
|
||||
if not next_link:
|
||||
return {
|
||||
'nextLink': 'page2',
|
||||
'value': ['value1.0', 'value1.1']
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'nextLink': None,
|
||||
'value': ['value2.0', 'value2.1']
|
||||
}
|
||||
|
||||
deserialized = FakePaged(None, {}, async_command=internal_paging)
|
||||
page1 = await deserialized.async_advance_page()
|
||||
assert ['value1.0', 'value1.1'] == page1
|
||||
|
||||
page2 = await deserialized.async_advance_page()
|
||||
assert ['value2.0', 'value2.1'] == page2
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await deserialized.async_advance_page()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paging(self):
|
||||
|
||||
async def internal_paging(next_link=None, raw=False):
|
||||
if not next_link:
|
||||
return {
|
||||
'nextLink': 'page2',
|
||||
'value': ['value1.0', 'value1.1']
|
||||
}
|
||||
elif next_link == 'page2':
|
||||
return {
|
||||
'nextLink': 'page3',
|
||||
'value': ['value2.0', 'value2.1']
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'nextLink': None,
|
||||
'value': ['value3.0', 'value3.1']
|
||||
}
|
||||
|
||||
deserialized = FakePaged(None, {}, async_command=internal_paging)
|
||||
page2 = await deserialized.async_get('page2')
|
||||
assert ['value2.0', 'value2.1'] == page2
|
||||
|
||||
page3 = await deserialized.async_get('page3')
|
||||
assert ['value3.0', 'value3.1'] == page3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_paging(self):
|
||||
|
||||
async def internal_paging(next_link=None, raw=False):
|
||||
if not next_link:
|
||||
return {
|
||||
'nextLink': 'page2',
|
||||
'value': ['value1.0', 'value1.1']
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'nextLink': None,
|
||||
'value': ['value2.0', 'value2.1']
|
||||
}
|
||||
|
||||
deserialized = FakePaged(None, {}, async_command=internal_paging)
|
||||
deserialized.reset()
|
||||
|
||||
# 3.6 only : result_iterated = [obj async for obj in deserialized]
|
||||
result_iterated = []
|
||||
async for obj in deserialized:
|
||||
result_iterated.append(obj)
|
||||
|
||||
assert ['value1.0', 'value1.1', 'value2.0', 'value2.1'] == result_iterated
|
||||
|
||||
deserialized = FakePaged(None, {}, async_command=internal_paging)
|
||||
# Push the iterator to the last element
|
||||
async for element in deserialized:
|
||||
if element == "value2.0":
|
||||
break
|
||||
deserialized.reset()
|
||||
|
||||
# 3.6 only : result_iterated = [obj async for obj in deserialized]
|
||||
result_iterated = []
|
||||
async for obj in deserialized:
|
||||
result_iterated.append(obj)
|
||||
|
||||
assert ['value1.0', 'value1.1', 'value2.0', 'value2.1'] == result_iterated
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_value(self):
|
||||
async def internal_paging(next_link=None, raw=False):
|
||||
return {
|
||||
'nextLink': None,
|
||||
'value': None
|
||||
}
|
||||
|
||||
deserialized = FakePaged(None, {}, async_command=internal_paging)
|
||||
|
||||
# 3.6 only : result_iterated = [obj async for obj in deserialized]
|
||||
result_iterated = []
|
||||
async for obj in deserialized:
|
||||
result_iterated.append(obj)
|
||||
|
||||
assert len(result_iterated) == 0
|
|
@ -0,0 +1,128 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#
|
||||
#--------------------------------------------------------------------------
|
||||
import sys
|
||||
|
||||
from msrest.universal_http import (
|
||||
ClientRequest,
|
||||
)
|
||||
from msrest.universal_http.async_requests import (
|
||||
AsyncRequestsHTTPSender,
|
||||
AsyncTrioRequestsHTTPSender,
|
||||
)
|
||||
|
||||
from msrest.pipeline import (
|
||||
AsyncPipeline,
|
||||
AsyncHTTPSender,
|
||||
SansIOHTTPPolicy
|
||||
)
|
||||
from msrest.pipeline.async_requests import AsyncPipelineRequestsHTTPSender
|
||||
from msrest.pipeline.universal import UserAgentPolicy
|
||||
from msrest.pipeline.aiohttp import AioHTTPSender
|
||||
|
||||
from msrest.configuration import Configuration
|
||||
|
||||
import trio
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sans_io_exception():
|
||||
class BrokenSender(AsyncHTTPSender):
|
||||
async def send(self, request, **config):
|
||||
raise ValueError("Broken")
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
pipeline = AsyncPipeline([SansIOHTTPPolicy()], BrokenSender())
|
||||
|
||||
req = ClientRequest('GET', '/')
|
||||
with pytest.raises(ValueError):
|
||||
await pipeline.run(req)
|
||||
|
||||
class SwapExec(SansIOHTTPPolicy):
|
||||
def on_exception(self, requests, **kwargs):
|
||||
exc_type, exc_value, exc_traceback = sys.exc_info()
|
||||
raise NotImplementedError(exc_value)
|
||||
|
||||
pipeline = AsyncPipeline([SwapExec()], BrokenSender())
|
||||
with pytest.raises(NotImplementedError):
|
||||
await pipeline.run(req)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_aiohttp():
|
||||
|
||||
request = ClientRequest("GET", "http://bing.com")
|
||||
policies = [
|
||||
UserAgentPolicy("myusergant")
|
||||
]
|
||||
async with AsyncPipeline(policies) as pipeline:
|
||||
response = await pipeline.run(request)
|
||||
|
||||
assert pipeline._sender.driver._session.closed
|
||||
assert response.http_response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_async_requests():
|
||||
|
||||
request = ClientRequest("GET", "http://bing.com")
|
||||
policies = [
|
||||
UserAgentPolicy("myusergant")
|
||||
]
|
||||
async with AsyncPipeline(policies, AsyncPipelineRequestsHTTPSender()) as pipeline:
|
||||
response = await pipeline.run(request)
|
||||
|
||||
assert response.http_response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conf_async_requests():
|
||||
|
||||
conf = Configuration("http://bing.com/")
|
||||
request = ClientRequest("GET", "http://bing.com/")
|
||||
policies = [
|
||||
UserAgentPolicy("myusergant")
|
||||
]
|
||||
async with AsyncPipeline(policies, AsyncPipelineRequestsHTTPSender(AsyncRequestsHTTPSender(conf))) as pipeline:
|
||||
response = await pipeline.run(request)
|
||||
|
||||
assert response.http_response.status_code == 200
|
||||
|
||||
def test_conf_async_trio_requests():
|
||||
|
||||
async def do():
|
||||
conf = Configuration("http://bing.com/")
|
||||
request = ClientRequest("GET", "http://bing.com/")
|
||||
policies = [
|
||||
UserAgentPolicy("myusergant")
|
||||
]
|
||||
async with AsyncPipeline(policies, AsyncPipelineRequestsHTTPSender(AsyncTrioRequestsHTTPSender(conf))) as pipeline:
|
||||
return await pipeline.run(request)
|
||||
|
||||
response = trio.run(do)
|
||||
assert response.http_response.status_code == 200
|
|
@ -0,0 +1,167 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#
|
||||
#--------------------------------------------------------------------------
|
||||
import asyncio
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from msrest.polling.async_poller import *
|
||||
from msrest.service_client import ServiceClient
|
||||
from msrest.serialization import Model
|
||||
from msrest.configuration import Configuration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abc_polling():
|
||||
abc_polling = AsyncPollingMethod()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
abc_polling.initialize(None, None, None)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await abc_polling.run()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
abc_polling.status()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
abc_polling.finished()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
abc_polling.resource()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_polling():
|
||||
no_polling = AsyncNoPolling()
|
||||
|
||||
initial_response = "initial response"
|
||||
def deserialization_cb(response):
|
||||
assert response == initial_response
|
||||
return "Treated: "+response
|
||||
|
||||
no_polling.initialize(None, initial_response, deserialization_cb)
|
||||
await no_polling.run() # Should no raise and do nothing
|
||||
assert no_polling.status() == "succeeded"
|
||||
assert no_polling.finished()
|
||||
assert no_polling.resource() == "Treated: "+initial_response
|
||||
|
||||
|
||||
class PollingTwoSteps(AsyncPollingMethod):
|
||||
"""An empty poller that returns the deserialized initial response.
|
||||
"""
|
||||
def __init__(self, sleep=0):
|
||||
self._initial_response = None
|
||||
self._deserialization_callback = None
|
||||
self._sleep = sleep
|
||||
|
||||
def initialize(self, _, initial_response, deserialization_callback):
|
||||
self._initial_response = initial_response
|
||||
self._deserialization_callback = deserialization_callback
|
||||
self._finished = False
|
||||
|
||||
async def run(self):
|
||||
"""Empty run, no polling.
|
||||
"""
|
||||
self._finished = True
|
||||
await asyncio.sleep(self._sleep) # Give me time to add callbacks!
|
||||
|
||||
def status(self):
|
||||
"""Return the current status as a string.
|
||||
:rtype: str
|
||||
"""
|
||||
return "succeeded" if self._finished else "running"
|
||||
|
||||
def finished(self):
|
||||
"""Is this polling finished?
|
||||
:rtype: bool
|
||||
"""
|
||||
return self._finished
|
||||
|
||||
def resource(self):
|
||||
return self._deserialization_callback(self._initial_response)
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
# We need a ServiceClient instance, but the poller itself don't use it, so we don't need
|
||||
# Something functionnal
|
||||
return ServiceClient(None, Configuration("http://example.org"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poller(client):
|
||||
|
||||
# Same the poller itself doesn't care about the initial_response, and there is no type constraint here
|
||||
initial_response = "Initial response"
|
||||
|
||||
# Same for deserialization_callback, just pass to the polling_method
|
||||
def deserialization_callback(response):
|
||||
assert response == initial_response
|
||||
return "Treated: "+response
|
||||
|
||||
method = AsyncNoPolling()
|
||||
|
||||
result = await async_poller(client, initial_response, deserialization_callback, method)
|
||||
assert result == "Treated: "+initial_response
|
||||
|
||||
# Test with a basic Model
|
||||
class MockedModel(Model):
|
||||
called = False
|
||||
@classmethod
|
||||
def deserialize(cls, data):
|
||||
assert data == initial_response
|
||||
cls.called = True
|
||||
|
||||
result = await async_poller(client, initial_response, MockedModel, method)
|
||||
assert MockedModel.called
|
||||
|
||||
# Test poller that method do a run
|
||||
method = PollingTwoSteps(sleep=2)
|
||||
result = await async_poller(client, initial_response, deserialization_callback, method)
|
||||
|
||||
assert result == "Treated: "+initial_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broken_poller(client):
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await async_poller(None, None, None, None)
|
||||
|
||||
class NoPollingError(PollingTwoSteps):
|
||||
async def run(self):
|
||||
raise ValueError("Something bad happened")
|
||||
|
||||
initial_response = "Initial response"
|
||||
def deserialization_callback(response):
|
||||
return "Treated: "+response
|
||||
|
||||
method = NoPollingError()
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
await async_poller(client, initial_response, deserialization_callback, method)
|
||||
assert "Something bad happened" in str(excinfo.value)
|
|
@ -0,0 +1,88 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#
|
||||
#--------------------------------------------------------------------------
|
||||
import sys
|
||||
|
||||
from msrest.universal_http import (
|
||||
ClientRequest,
|
||||
AsyncHTTPSender,
|
||||
)
|
||||
from msrest.universal_http.aiohttp import AioHTTPSender
|
||||
from msrest.universal_http.async_requests import (
|
||||
AsyncBasicRequestsHTTPSender,
|
||||
AsyncRequestsHTTPSender,
|
||||
AsyncTrioRequestsHTTPSender,
|
||||
)
|
||||
|
||||
from msrest.configuration import Configuration
|
||||
|
||||
import trio
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_aiohttp():
|
||||
|
||||
request = ClientRequest("GET", "http://bing.com")
|
||||
async with AioHTTPSender() as sender:
|
||||
response = await sender.send(request)
|
||||
assert response.body() is not None
|
||||
|
||||
assert sender._session.closed
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_async_requests():
|
||||
|
||||
request = ClientRequest("GET", "http://bing.com")
|
||||
async with AsyncBasicRequestsHTTPSender() as sender:
|
||||
response = await sender.send(request)
|
||||
assert response.body() is not None
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conf_async_requests():
|
||||
|
||||
conf = Configuration("http://bing.com/")
|
||||
request = ClientRequest("GET", "http://bing.com/")
|
||||
async with AsyncRequestsHTTPSender(conf) as sender:
|
||||
response = await sender.send(request)
|
||||
assert response.body() is not None
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_conf_async_trio_requests():
|
||||
|
||||
async def do():
|
||||
conf = Configuration("http://bing.com/")
|
||||
request = ClientRequest("GET", "http://bing.com/")
|
||||
async with AsyncTrioRequestsHTTPSender(conf) as sender:
|
||||
return await sender.send(request)
|
||||
assert response.body() is not None
|
||||
|
||||
response = trio.run(do)
|
||||
assert response.status_code == 200
|
|
@ -0,0 +1,31 @@
|
|||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to
|
||||
# deal in the Software without restriction, including without limitation the
|
||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
# sell copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
# IN THE SOFTWARE.
|
||||
#
|
||||
# --------------------------------------------------------------------------
|
||||
import sys
|
||||
|
||||
# Ignore collection of async tests for Python 2
|
||||
collect_ignore = []
|
||||
if sys.version_info < (3, 5):
|
||||
collect_ignore.append("asynctests")
|
|
@ -31,18 +31,30 @@ try:
|
|||
from unittest import mock
|
||||
except ImportError:
|
||||
import mock
|
||||
import sys
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from oauthlib import oauth2
|
||||
|
||||
from msrest import ServiceClient, SDKClient
|
||||
from msrest.service_client import _RequestsHTTPDriver
|
||||
from msrest.universal_http import (
|
||||
ClientRequest,
|
||||
ClientResponse
|
||||
)
|
||||
from msrest.universal_http.requests import (
|
||||
RequestsHTTPSender,
|
||||
RequestsClientResponse
|
||||
)
|
||||
|
||||
from msrest.pipeline import (
|
||||
HTTPSender,
|
||||
Response,
|
||||
)
|
||||
from msrest.authentication import OAuthTokenAuthentication, Authentication
|
||||
|
||||
from msrest import Configuration
|
||||
from msrest.exceptions import ClientRequestError, TokenExpiredError
|
||||
from msrest.pipeline import ClientRequest
|
||||
|
||||
|
||||
class TestServiceClient(unittest.TestCase):
|
||||
|
@ -53,19 +65,6 @@ class TestServiceClient(unittest.TestCase):
|
|||
self.creds = mock.create_autospec(OAuthTokenAuthentication)
|
||||
return super(TestServiceClient, self).setUp()
|
||||
|
||||
def test_session_callback(self):
|
||||
|
||||
with _RequestsHTTPDriver(self.cfg) as driver:
|
||||
|
||||
def callback(session, global_config, local_config, **kwargs):
|
||||
self.assertIs(session, driver.session)
|
||||
self.assertIs(global_config, self.cfg)
|
||||
self.assertTrue(local_config["test"])
|
||||
return {'used_callback': True}
|
||||
self.cfg.session_configuration_callback = callback
|
||||
|
||||
output_kwargs = driver.configure_session(**{"test": True})
|
||||
self.assertTrue(output_kwargs['used_callback'])
|
||||
|
||||
def test_sdk_context_manager(self):
|
||||
cfg = Configuration("http://127.0.0.1/")
|
||||
|
@ -87,14 +86,16 @@ class TestServiceClient(unittest.TestCase):
|
|||
with SDKClient(creds, cfg) as client:
|
||||
assert cfg.keep_alive
|
||||
|
||||
req = client._client.get()
|
||||
req = client._client.get('/')
|
||||
try:
|
||||
client._client.send(req) # Will fail, I don't care, that's not the point of the test
|
||||
# Will fail, I don't care, that's not the point of the test
|
||||
client._client.send(req, timeout=0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
client._client.send(req) # Will fail, I don't care, that's not the point of the test
|
||||
# Will fail, I don't care, that's not the point of the test
|
||||
client._client.send(req, timeout=0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
@ -122,17 +123,18 @@ class TestServiceClient(unittest.TestCase):
|
|||
with ServiceClient(creds, cfg) as client:
|
||||
assert cfg.keep_alive
|
||||
|
||||
req = client.get()
|
||||
req = client.get('/')
|
||||
try:
|
||||
client.send(req) # Will fail, I don't care, that's not the point of the test
|
||||
# Will fail, I don't care, that's not the point of the test
|
||||
client.send(req, timeout=0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
client.send(req) # Will fail, I don't care, that's not the point of the test
|
||||
# Will fail, I don't care, that's not the point of the test
|
||||
client.send(req, timeout=0)
|
||||
except Exception:
|
||||
pass
|
||||
assert client._http_driver.session # Still alive
|
||||
|
||||
assert not cfg.keep_alive
|
||||
assert creds.called == 2
|
||||
|
@ -157,101 +159,60 @@ class TestServiceClient(unittest.TestCase):
|
|||
creds = Creds()
|
||||
|
||||
client = ServiceClient(creds, cfg)
|
||||
req = client.get()
|
||||
req = client.get('/')
|
||||
try:
|
||||
client.send(req) # Will fail, I don't care, that's not the point of the test
|
||||
# Will fail, I don't care, that's not the point of the test
|
||||
client.send(req, timeout=0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
client.send(req) # Will fail, I don't care, that's not the point of the test
|
||||
# Will fail, I don't care, that's not the point of the test
|
||||
client.send(req, timeout=0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assert creds.called == 2
|
||||
assert client._http_driver.session # Still alive
|
||||
# Manually close the client in "keep_alive" mode
|
||||
client.close()
|
||||
|
||||
def test_max_retries_on_default_adapter(self):
|
||||
# max_retries must be applied only on the default adapters of requests
|
||||
# If the user adds its own adapter, don't touch it
|
||||
max_retries = self.cfg.retry_policy()
|
||||
|
||||
with _RequestsHTTPDriver(self.cfg) as driver:
|
||||
driver.session.mount('http://example.org', HTTPAdapter())
|
||||
driver.configure_session()
|
||||
assert driver.session.adapters["http://"].max_retries is max_retries
|
||||
assert driver.session.adapters["https://"].max_retries is max_retries
|
||||
assert driver.session.adapters['http://example.org'].max_retries is not max_retries
|
||||
|
||||
def test_no_log(self):
|
||||
|
||||
# By default, no log handler for HTTP
|
||||
with _RequestsHTTPDriver(self.cfg) as driver:
|
||||
kwargs = driver.configure_session()
|
||||
assert 'hooks' not in kwargs
|
||||
|
||||
# I can enable it per request
|
||||
with _RequestsHTTPDriver(self.cfg) as driver:
|
||||
kwargs = driver.configure_session(**{"enable_http_logger": True})
|
||||
assert 'hooks' in kwargs
|
||||
|
||||
# I can enable it per request (bool value should be honored)
|
||||
with _RequestsHTTPDriver(self.cfg) as driver:
|
||||
kwargs = driver.configure_session(**{"enable_http_logger": False})
|
||||
assert 'hooks' not in kwargs
|
||||
|
||||
# I can enable it globally
|
||||
self.cfg.enable_http_logger = True
|
||||
with _RequestsHTTPDriver(self.cfg) as driver:
|
||||
kwargs = driver.configure_session()
|
||||
assert 'hooks' in kwargs
|
||||
|
||||
# I can enable it globally and override it locally
|
||||
self.cfg.enable_http_logger = True
|
||||
with _RequestsHTTPDriver(self.cfg) as driver:
|
||||
kwargs = driver.configure_session(**{"enable_http_logger": False})
|
||||
assert 'hooks' not in kwargs
|
||||
|
||||
def test_client_request(self):
|
||||
|
||||
client = ServiceClient(self.creds, self.cfg)
|
||||
obj = client.get()
|
||||
cfg = Configuration("http://127.0.0.1/")
|
||||
client = ServiceClient(self.creds, cfg)
|
||||
obj = client.get('/')
|
||||
self.assertEqual(obj.method, 'GET')
|
||||
self.assertIsNone(obj.url)
|
||||
self.assertEqual(obj.params, {})
|
||||
self.assertEqual(obj.url, "http://127.0.0.1/")
|
||||
|
||||
obj = client.get("/service", {'param':"testing"})
|
||||
self.assertEqual(obj.method, 'GET')
|
||||
self.assertEqual(obj.url, "https://my_endpoint.com/service?param=testing")
|
||||
self.assertEqual(obj.params, {})
|
||||
self.assertEqual(obj.url, "http://127.0.0.1/service?param=testing")
|
||||
|
||||
obj = client.get("service 2")
|
||||
self.assertEqual(obj.method, 'GET')
|
||||
self.assertEqual(obj.url, "https://my_endpoint.com/service 2")
|
||||
self.assertEqual(obj.url, "http://127.0.0.1/service 2")
|
||||
|
||||
self.cfg.base_url = "https://my_endpoint.com/"
|
||||
cfg.base_url = "https://127.0.0.1/"
|
||||
obj = client.get("//service3")
|
||||
self.assertEqual(obj.method, 'GET')
|
||||
self.assertEqual(obj.url, "https://my_endpoint.com/service3")
|
||||
self.assertEqual(obj.url, "https://127.0.0.1/service3")
|
||||
|
||||
obj = client.put()
|
||||
obj = client.put('/')
|
||||
self.assertEqual(obj.method, 'PUT')
|
||||
|
||||
obj = client.post()
|
||||
obj = client.post('/')
|
||||
self.assertEqual(obj.method, 'POST')
|
||||
|
||||
obj = client.head()
|
||||
obj = client.head('/')
|
||||
self.assertEqual(obj.method, 'HEAD')
|
||||
|
||||
obj = client.merge()
|
||||
obj = client.merge('/')
|
||||
self.assertEqual(obj.method, 'MERGE')
|
||||
|
||||
obj = client.patch()
|
||||
obj = client.patch('/')
|
||||
self.assertEqual(obj.method, 'PATCH')
|
||||
|
||||
obj = client.delete()
|
||||
obj = client.delete('/')
|
||||
self.assertEqual(obj.method, 'DELETE')
|
||||
|
||||
def test_format_url(self):
|
||||
|
@ -288,53 +249,60 @@ class TestServiceClient(unittest.TestCase):
|
|||
|
||||
|
||||
def test_client_send(self):
|
||||
|
||||
class MockHTTPDriver(object):
|
||||
def configure_session(self, **config):
|
||||
pass
|
||||
def send(self, request, **config):
|
||||
pass
|
||||
current_ua = self.cfg.user_agent
|
||||
|
||||
client = ServiceClient(self.creds, self.cfg)
|
||||
client.config.keep_alive = True
|
||||
|
||||
req_response = requests.Response()
|
||||
req_response._content = br'{"real": true}' # Has to be valid bytes JSON
|
||||
req_response._content_consumed = True
|
||||
req_response.status_code = 200
|
||||
|
||||
def side_effect(*args, **kwargs):
|
||||
return req_response
|
||||
|
||||
session = mock.create_autospec(requests.Session)
|
||||
client._http_driver.session = session
|
||||
session.request.side_effect = side_effect
|
||||
session.adapters = {
|
||||
"http://": HTTPAdapter(),
|
||||
"https://": HTTPAdapter(),
|
||||
}
|
||||
client.creds.signed_session.return_value = session
|
||||
client.creds.refresh_session.return_value = session
|
||||
# Be sure the mock does not trick me
|
||||
assert not hasattr(session.resolve_redirects, 'is_mrest_patched')
|
||||
assert not hasattr(session.resolve_redirects, 'is_msrest_patched')
|
||||
|
||||
request = ClientRequest('GET')
|
||||
client.config.pipeline._sender.driver.session = session
|
||||
|
||||
client._creds.signed_session.return_value = session
|
||||
client._creds.refresh_session.return_value = session
|
||||
|
||||
request = ClientRequest('GET', '/')
|
||||
client.send(request, stream=False)
|
||||
session.request.call_count = 0
|
||||
session.request.assert_called_with(
|
||||
'GET',
|
||||
None,
|
||||
'/',
|
||||
allow_redirects=True,
|
||||
cert=None,
|
||||
headers={
|
||||
'User-Agent': self.cfg.user_agent,
|
||||
'User-Agent': current_ua,
|
||||
'Test': 'true' # From global config
|
||||
},
|
||||
stream=False,
|
||||
timeout=100,
|
||||
verify=True
|
||||
)
|
||||
assert session.resolve_redirects.is_mrest_patched
|
||||
assert session.resolve_redirects.is_msrest_patched
|
||||
|
||||
client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, stream=False)
|
||||
session.request.assert_called_with(
|
||||
'GET',
|
||||
None,
|
||||
'/',
|
||||
data='{"Test": "Data"}',
|
||||
allow_redirects=True,
|
||||
cert=None,
|
||||
headers={
|
||||
'User-Agent': self.cfg.user_agent,
|
||||
'User-Agent': current_ua,
|
||||
'Content-Length': '16',
|
||||
'id':'1234',
|
||||
'Test': 'true' # From global config
|
||||
|
@ -345,19 +313,19 @@ class TestServiceClient(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(session.request.call_count, 1)
|
||||
session.request.call_count = 0
|
||||
assert session.resolve_redirects.is_mrest_patched
|
||||
assert session.resolve_redirects.is_msrest_patched
|
||||
|
||||
session.request.side_effect = requests.RequestException("test")
|
||||
with self.assertRaises(ClientRequestError):
|
||||
client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value', stream=False)
|
||||
session.request.assert_called_with(
|
||||
'GET',
|
||||
None,
|
||||
'/',
|
||||
data='{"Test": "Data"}',
|
||||
allow_redirects=True,
|
||||
cert=None,
|
||||
headers={
|
||||
'User-Agent': self.cfg.user_agent,
|
||||
'User-Agent': current_ua,
|
||||
'Content-Length': '16',
|
||||
'id':'1234',
|
||||
'Test': 'true' # From global config
|
||||
|
@ -368,7 +336,7 @@ class TestServiceClient(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(session.request.call_count, 1)
|
||||
session.request.call_count = 0
|
||||
assert session.resolve_redirects.is_mrest_patched
|
||||
assert session.resolve_redirects.is_msrest_patched
|
||||
|
||||
session.request.side_effect = oauth2.rfc6749.errors.InvalidGrantError("test")
|
||||
with self.assertRaises(TokenExpiredError):
|
||||
|
@ -380,77 +348,95 @@ class TestServiceClient(unittest.TestCase):
|
|||
with self.assertRaises(ValueError):
|
||||
client.send(request, headers={'id':'1234'}, content={'Test':'Data'}, test='value')
|
||||
|
||||
def test_client_formdata_add(self):
|
||||
@mock.patch.object(ClientRequest, "_format_data")
|
||||
def test_client_formdata_add(self, format_data):
|
||||
format_data.return_value = "formatted"
|
||||
|
||||
client = mock.create_autospec(ServiceClient)
|
||||
client._format_data.return_value = "formatted"
|
||||
|
||||
request = ClientRequest('GET')
|
||||
ServiceClient._add_formdata(client, request)
|
||||
request = ClientRequest('GET', '/')
|
||||
request.add_formdata()
|
||||
assert request.files == {}
|
||||
|
||||
request = ClientRequest('GET')
|
||||
ServiceClient._add_formdata(client, request, {'Test':'Data'})
|
||||
request = ClientRequest('GET', '/')
|
||||
request.add_formdata({'Test':'Data'})
|
||||
assert request.files == {'Test':'formatted'}
|
||||
|
||||
request = ClientRequest('GET')
|
||||
request = ClientRequest('GET', '/')
|
||||
request.headers = {'Content-Type':'1234'}
|
||||
ServiceClient._add_formdata(client, request, {'1':'1', '2':'2'})
|
||||
request.add_formdata({'1':'1', '2':'2'})
|
||||
assert request.files == {'1':'formatted', '2':'formatted'}
|
||||
|
||||
request = ClientRequest('GET')
|
||||
request = ClientRequest('GET', '/')
|
||||
request.headers = {'Content-Type':'1234'}
|
||||
ServiceClient._add_formdata(client, request, {'1':'1', '2':None})
|
||||
request.add_formdata({'1':'1', '2':None})
|
||||
assert request.files == {'1':'formatted'}
|
||||
|
||||
request = ClientRequest('GET')
|
||||
request = ClientRequest('GET', '/')
|
||||
request.headers = {'Content-Type':'application/x-www-form-urlencoded'}
|
||||
ServiceClient._add_formdata(client, request, {'1':'1', '2':'2'})
|
||||
assert request.files == []
|
||||
request.add_formdata({'1':'1', '2':'2'})
|
||||
assert request.files is None
|
||||
assert request.data == {'1':'1', '2':'2'}
|
||||
|
||||
request = ClientRequest('GET')
|
||||
request = ClientRequest('GET', '/')
|
||||
request.headers = {'Content-Type':'application/x-www-form-urlencoded'}
|
||||
ServiceClient._add_formdata(client, request, {'1':'1', '2':None})
|
||||
assert request.files == []
|
||||
request.add_formdata({'1':'1', '2':None})
|
||||
assert request.files is None
|
||||
assert request.data == {'1':'1'}
|
||||
|
||||
def test_format_data(self):
|
||||
|
||||
mock_client = mock.create_autospec(ServiceClient)
|
||||
data = ServiceClient._format_data(mock_client, None)
|
||||
data = ClientRequest._format_data(None)
|
||||
self.assertEqual(data, (None, None))
|
||||
|
||||
data = ServiceClient._format_data(mock_client, "Test")
|
||||
data = ClientRequest._format_data("Test")
|
||||
self.assertEqual(data, (None, "Test"))
|
||||
|
||||
mock_stream = mock.create_autospec(io.BytesIO)
|
||||
data = ServiceClient._format_data(mock_client, mock_stream)
|
||||
data = ClientRequest._format_data(mock_stream)
|
||||
self.assertEqual(data, (None, mock_stream, "application/octet-stream"))
|
||||
|
||||
mock_stream.name = "file_name"
|
||||
data = ServiceClient._format_data(mock_client, mock_stream)
|
||||
data = ClientRequest._format_data(mock_stream)
|
||||
self.assertEqual(data, ("file_name", mock_stream, "application/octet-stream"))
|
||||
|
||||
def test_client_stream_download(self):
|
||||
|
||||
req_response = requests.Response()
|
||||
req_response._content = "abc"
|
||||
req_response._content_consumed = True
|
||||
req_response.status_code = 200
|
||||
|
||||
client_response = RequestsClientResponse(
|
||||
None,
|
||||
req_response
|
||||
)
|
||||
|
||||
def user_callback(chunk, response):
|
||||
assert response is req_response
|
||||
assert chunk in ["a", "b", "c"]
|
||||
|
||||
sync_iterator = client_response.stream_download(1, user_callback)
|
||||
result = ""
|
||||
for value in sync_iterator:
|
||||
result += value
|
||||
assert result == "abc"
|
||||
|
||||
def test_request_builder(self):
|
||||
client = ServiceClient(self.creds, self.cfg)
|
||||
|
||||
req = client.get('http://example.org')
|
||||
req = client.get('http://127.0.0.1/')
|
||||
assert req.method == 'GET'
|
||||
assert req.url == 'http://example.org'
|
||||
assert req.params == {}
|
||||
assert req.url == 'http://127.0.0.1/'
|
||||
assert req.headers == {'Accept': 'application/json'}
|
||||
assert req.data == []
|
||||
assert req.files == []
|
||||
assert req.data is None
|
||||
assert req.files is None
|
||||
|
||||
req = client.put('http://example.org', content={'creation': True})
|
||||
req = client.put("http://127.0.0.1/", content={'creation': True})
|
||||
assert req.method == 'PUT'
|
||||
assert req.url == 'http://example.org'
|
||||
assert req.params == {}
|
||||
assert req.url == "http://127.0.0.1/"
|
||||
assert req.headers == {'Content-Length': '18', 'Accept': 'application/json'}
|
||||
assert req.data == '{"creation": true}'
|
||||
assert req.files == []
|
||||
assert req.files is None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
|
@ -89,15 +89,16 @@ class TestExceptions(unittest.TestCase):
|
|||
'ErrorDetails': ErrorDetails
|
||||
})
|
||||
|
||||
response = mock.create_autospec(requests.Response)
|
||||
response.text = json.dumps(
|
||||
response = requests.Response()
|
||||
response._content_consumed = True
|
||||
response._content = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"code": "NotOptedIn",
|
||||
"message": "You are not allowed to download invoices. Please contact your account administrator to turn on access in the management portal for allowing to download invoices through the API."
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
).encode('utf-8')
|
||||
response.headers = {"content-type": "application/json; charset=utf8"}
|
||||
|
||||
excep = ErrorResponseException(deserializer, response)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
|
@ -34,20 +34,53 @@ try:
|
|||
except ImportError:
|
||||
import mock
|
||||
import xml.etree.ElementTree as ET
|
||||
import sys
|
||||
|
||||
from msrest.pipeline import (
|
||||
import pytest
|
||||
|
||||
from msrest.universal_http import (
|
||||
ClientRequest,
|
||||
ClientRawResponse)
|
||||
)
|
||||
from msrest.pipeline import (
|
||||
ClientRawResponse,
|
||||
SansIOHTTPPolicy,
|
||||
Pipeline,
|
||||
HTTPSender
|
||||
)
|
||||
|
||||
from msrest import Configuration
|
||||
|
||||
|
||||
def test_sans_io_exception():
|
||||
class BrokenSender(HTTPSender):
|
||||
def send(self, request, **config):
|
||||
raise ValueError("Broken")
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
pipeline = Pipeline([SansIOHTTPPolicy()], BrokenSender())
|
||||
|
||||
req = ClientRequest('GET', '/')
|
||||
with pytest.raises(ValueError):
|
||||
pipeline.run(req)
|
||||
|
||||
class SwapExec(SansIOHTTPPolicy):
|
||||
def on_exception(self, requests, **kwargs):
|
||||
exc_type, exc_value, exc_traceback = sys.exc_info()
|
||||
raise NotImplementedError(exc_value)
|
||||
|
||||
pipeline = Pipeline([SwapExec()], BrokenSender())
|
||||
with pytest.raises(NotImplementedError):
|
||||
pipeline.run(req)
|
||||
|
||||
|
||||
class TestClientRequest(unittest.TestCase):
|
||||
|
||||
def test_request_data(self):
|
||||
|
||||
request = ClientRequest()
|
||||
request = ClientRequest('GET', '/')
|
||||
data = "Lots of dataaaa"
|
||||
request.add_content(data)
|
||||
|
||||
|
@ -55,7 +88,7 @@ class TestClientRequest(unittest.TestCase):
|
|||
self.assertEqual(request.headers.get('Content-Length'), '17')
|
||||
|
||||
def test_request_xml(self):
|
||||
request = ClientRequest()
|
||||
request = ClientRequest('GET', '/')
|
||||
data = ET.Element("root")
|
||||
request.add_content(data)
|
||||
|
||||
|
@ -63,7 +96,7 @@ class TestClientRequest(unittest.TestCase):
|
|||
|
||||
def test_request_url_with_params(self):
|
||||
|
||||
request = ClientRequest()
|
||||
request = ClientRequest('GET', '/')
|
||||
request.url = "a/b/c?t=y"
|
||||
request.format_parameters({'g': 'h'})
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
|
@ -34,6 +34,7 @@ import pytest
|
|||
from msrest.polling import *
|
||||
from msrest.service_client import ServiceClient
|
||||
from msrest.serialization import Model
|
||||
from msrest.configuration import Configuration
|
||||
|
||||
|
||||
def test_abc_polling():
|
||||
|
@ -103,11 +104,13 @@ class PollingTwoSteps(PollingMethod):
|
|||
def resource(self):
|
||||
return self._deserialization_callback(self._initial_response)
|
||||
|
||||
def test_poller():
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
# We need a ServiceClient instance, but the poller itself don't use it, so we don't need
|
||||
# Something functionnal
|
||||
client = ServiceClient(None, None)
|
||||
return ServiceClient(None, Configuration("http://example.org"))
|
||||
|
||||
def test_poller(client):
|
||||
|
||||
# Same the poller itself doesn't care about the initial_response, and there is no type constraint here
|
||||
initial_response = "Initial response"
|
||||
|
@ -115,7 +118,7 @@ def test_poller():
|
|||
# Same for deserialization_callback, just pass to the polling_method
|
||||
def deserialization_callback(response):
|
||||
assert response == initial_response
|
||||
return "Treated: "+response
|
||||
return "Treated: "+response
|
||||
|
||||
method = NoPolling()
|
||||
|
||||
|
@ -135,7 +138,7 @@ def test_poller():
|
|||
assert poller._polling_method._deserialization_callback == Model.deserialize
|
||||
|
||||
# Test poller that method do a run
|
||||
method = PollingTwoSteps(sleep=2)
|
||||
method = PollingTwoSteps(sleep=1)
|
||||
poller = LROPoller(client, initial_response, deserialization_callback, method)
|
||||
|
||||
done_cb = mock.MagicMock()
|
||||
|
@ -153,7 +156,7 @@ def test_poller():
|
|||
poller.remove_done_callback(done_cb)
|
||||
assert "Process is complete" in str(excinfo.value)
|
||||
|
||||
def test_broken_poller():
|
||||
def test_broken_poller(client):
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LROPoller(None, None, None, None)
|
||||
|
@ -162,10 +165,9 @@ def test_broken_poller():
|
|||
def run(self):
|
||||
raise ValueError("Something bad happened")
|
||||
|
||||
client = ServiceClient(None, None)
|
||||
initial_response = "Initial response"
|
||||
def deserialization_callback(response):
|
||||
return "Treated: "+response
|
||||
return "Treated: "+response
|
||||
|
||||
method = NoPollingError()
|
||||
poller = LROPoller(client, initial_response, deserialization_callback, method)
|
||||
|
@ -173,4 +175,3 @@ def test_broken_poller():
|
|||
with pytest.raises(ValueError) as excinfo:
|
||||
poller.result()
|
||||
assert "Something bad happened" in str(excinfo.value)
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#
|
||||
#--------------------------------------------------------------------------
|
||||
import concurrent.futures
|
||||
|
||||
from requests.adapters import HTTPAdapter
|
||||
|
||||
from msrest.universal_http import (
|
||||
ClientRequest
|
||||
)
|
||||
from msrest.universal_http.requests import (
|
||||
BasicRequestsHTTPSender,
|
||||
RequestsHTTPSender,
|
||||
RequestHTTPSenderConfiguration
|
||||
)
|
||||
|
||||
def test_session_callback():
|
||||
|
||||
cfg = RequestHTTPSenderConfiguration()
|
||||
with RequestsHTTPSender(cfg) as driver:
|
||||
|
||||
def callback(session, global_config, local_config, **kwargs):
|
||||
assert session is driver.session
|
||||
assert global_config is cfg
|
||||
assert local_config["test"]
|
||||
my_kwargs = kwargs.copy()
|
||||
my_kwargs.update({'used_callback': True})
|
||||
return my_kwargs
|
||||
|
||||
cfg.session_configuration_callback = callback
|
||||
|
||||
request = ClientRequest('GET', 'http://127.0.0.1/')
|
||||
output_kwargs = driver._configure_send(request, **{"test": True})
|
||||
assert output_kwargs['used_callback']
|
||||
|
||||
def test_max_retries_on_default_adapter():
|
||||
# max_retries must be applied only on the default adapters of requests
|
||||
# If the user adds its own adapter, don't touch it
|
||||
cfg = RequestHTTPSenderConfiguration()
|
||||
max_retries = cfg.retry_policy()
|
||||
|
||||
with RequestsHTTPSender(cfg) as driver:
|
||||
request = ClientRequest('GET', '/')
|
||||
driver.session.mount('"http://127.0.0.1/"', HTTPAdapter())
|
||||
|
||||
driver._configure_send(request)
|
||||
assert driver.session.adapters["http://"].max_retries is max_retries
|
||||
assert driver.session.adapters["https://"].max_retries is max_retries
|
||||
assert driver.session.adapters['"http://127.0.0.1/"'].max_retries is not max_retries
|
||||
|
||||
|
||||
def test_threading_basic_requests():
|
||||
# Basic should have the session for all threads, it's why it's not recommended
|
||||
sender = BasicRequestsHTTPSender()
|
||||
main_thread_session = sender.session
|
||||
|
||||
def thread_body(local_sender):
|
||||
# Should be the same session
|
||||
assert local_sender.session is main_thread_session
|
||||
|
||||
return True
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(thread_body, sender)
|
||||
assert future.result()
|
||||
|
||||
def test_threading_cfg_requests():
|
||||
cfg = RequestHTTPSenderConfiguration()
|
||||
|
||||
# The one with conf however, should have one session per thread automatically
|
||||
sender = RequestsHTTPSender(cfg)
|
||||
main_thread_session = sender.session
|
||||
|
||||
# Check that this main session is patched
|
||||
assert main_thread_session.resolve_redirects.is_msrest_patched
|
||||
|
||||
def thread_body(local_sender):
|
||||
# Should have it's own session
|
||||
assert local_sender.session is not main_thread_session
|
||||
# But should be patched as the main thread session
|
||||
assert local_sender.session.resolve_redirects.is_msrest_patched
|
||||
return True
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(thread_body, sender)
|
||||
assert future.result()
|
|
@ -1,6 +1,6 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
|
@ -45,8 +45,9 @@ except ImportError:
|
|||
from msrest.authentication import (
|
||||
Authentication,
|
||||
OAuthTokenAuthentication)
|
||||
from msrest.pipeline import (
|
||||
ClientRequest)
|
||||
from msrest.universal_http import (
|
||||
ClientRequest
|
||||
)
|
||||
from msrest import (
|
||||
ServiceClient,
|
||||
Configuration)
|
||||
|
@ -54,12 +55,13 @@ from msrest.exceptions import (
|
|||
TokenExpiredError,
|
||||
ClientRequestError)
|
||||
|
||||
import pytest
|
||||
|
||||
class TestRuntime(unittest.TestCase):
|
||||
|
||||
@httpretty.activate
|
||||
def test_credential_headers(self):
|
||||
|
||||
|
||||
httpretty.register_uri(httpretty.GET, "https://my_service.com/get_endpoint",
|
||||
body='[{"title": "Test Data"}]',
|
||||
content_type="application/json")
|
||||
|
@ -79,10 +81,12 @@ class TestRuntime(unittest.TestCase):
|
|||
url = client.format_url("/get_endpoint")
|
||||
request = client.get(url, {'check':True})
|
||||
response = client.send(request)
|
||||
self.assertTrue('Authorization' in response.request.headers)
|
||||
self.assertEqual(response.request.headers['Authorization'], 'Bearer eswfld123kjhn1v5423')
|
||||
check = httpretty.last_request()
|
||||
self.assertEqual(response.json(), [{"title": "Test Data"}])
|
||||
assert 'Authorization' in response.request.headers
|
||||
assert response.request.headers['Authorization'] == 'Bearer eswfld123kjhn1v5423'
|
||||
httpretty.has_request()
|
||||
assert response.json() == [{"title": "Test Data"}]
|
||||
|
||||
# Expiration test
|
||||
|
||||
token['expires_in'] = '-30'
|
||||
creds = OAuthTokenAuthentication("client_id", token)
|
||||
|
@ -90,13 +94,13 @@ class TestRuntime(unittest.TestCase):
|
|||
url = client.format_url("/get_endpoint")
|
||||
request = client.get(url, {'check':True})
|
||||
|
||||
with self.assertRaises(TokenExpiredError):
|
||||
with pytest.raises(TokenExpiredError):
|
||||
response = client.send(request)
|
||||
|
||||
@mock.patch.object(requests, 'Session')
|
||||
def test_request_fail(self, mock_requests):
|
||||
|
||||
mock_requests.return_value.request.return_value = mock.Mock(_content_consumed=True)
|
||||
mock_requests.return_value.request.return_value = mock.Mock(text="text")
|
||||
|
||||
cfg = Configuration("https://my_service.com")
|
||||
creds = Authentication()
|
||||
|
@ -106,9 +110,7 @@ class TestRuntime(unittest.TestCase):
|
|||
request = client.get(url, {'check':True})
|
||||
response = client.send(request)
|
||||
|
||||
check = httpretty.last_request()
|
||||
|
||||
self.assertTrue(response._content_consumed)
|
||||
assert response.text == "text"
|
||||
|
||||
mock_requests.return_value.request.side_effect = requests.RequestException
|
||||
with self.assertRaises(ClientRequestError):
|
||||
|
@ -131,7 +133,7 @@ class TestRuntime(unittest.TestCase):
|
|||
url = client.format_url("/get_endpoint")
|
||||
request = client.get(url, {'check':True})
|
||||
response = client.send(request)
|
||||
self.assertEqual(response.json(), "Mocked body")
|
||||
assert response.json() == "Mocked body"
|
||||
|
||||
with mock.patch.dict('os.environ', {'HTTP_PROXY': "http://localhost:1987"}):
|
||||
httpretty.register_uri(httpretty.GET, "http://localhost:1987/get_endpoint?check=True",
|
||||
|
@ -144,7 +146,7 @@ class TestRuntime(unittest.TestCase):
|
|||
url = client.format_url("/get_endpoint")
|
||||
request = client.get(url, {'check':True})
|
||||
response = client.send(request)
|
||||
self.assertEqual(response.json(), "Mocked body")
|
||||
assert response.json() == "Mocked body"
|
||||
|
||||
|
||||
class TestRedirect(unittest.TestCase):
|
||||
|
@ -171,14 +173,14 @@ class TestRedirect(unittest.TestCase):
|
|||
responses=[
|
||||
httpretty.Response(body="", status=303, method='POST', location='/http/success/get/200'),
|
||||
])
|
||||
|
||||
|
||||
response = self.client.send(request)
|
||||
self.assertEqual(response.status_code, 200, msg="Should redirect with GET on 303 with location header")
|
||||
self.assertEqual(response.request.method, 'GET')
|
||||
|
||||
self.assertEqual(response.history[0].status_code, 303)
|
||||
self.assertTrue(response.history[0].is_redirect)
|
||||
|
||||
response = self.client.send(request)
|
||||
assert response.status_code == 200, "Should redirect with GET on 303 with location header"
|
||||
assert response.request.method == 'GET'
|
||||
|
||||
assert response.history[0].status_code == 303
|
||||
assert response.history[0].is_redirect
|
||||
|
||||
httpretty.reset()
|
||||
httpretty.register_uri(httpretty.POST, "https://my_service.com/get_endpoint",
|
||||
|
@ -187,9 +189,9 @@ class TestRedirect(unittest.TestCase):
|
|||
])
|
||||
|
||||
response = self.client.send(request)
|
||||
self.assertEqual(response.status_code, 303, msg="Should not redirect on 303 without location header")
|
||||
self.assertEqual(response.history, [])
|
||||
self.assertFalse(response.is_redirect)
|
||||
assert response.status_code == 303, "Should not redirect on 303 without location header"
|
||||
assert response.history == []
|
||||
assert not response.is_redirect
|
||||
|
||||
@httpretty.activate
|
||||
def test_request_redirect_head(self):
|
||||
|
@ -202,14 +204,14 @@ class TestRedirect(unittest.TestCase):
|
|||
responses=[
|
||||
httpretty.Response(body="", status=307, method='HEAD', location='/http/success/200'),
|
||||
])
|
||||
|
||||
|
||||
response = self.client.send(request)
|
||||
self.assertEqual(response.status_code, 200, msg="Should redirect on 307 with location header")
|
||||
self.assertEqual(response.request.method, 'HEAD')
|
||||
|
||||
self.assertEqual(response.history[0].status_code, 307)
|
||||
self.assertTrue(response.history[0].is_redirect)
|
||||
|
||||
response = self.client.send(request)
|
||||
assert response.status_code == 200, "Should redirect on 307 with location header"
|
||||
assert response.request.method == 'HEAD'
|
||||
|
||||
assert response.history[0].status_code == 307
|
||||
assert response.history[0].is_redirect
|
||||
|
||||
httpretty.reset()
|
||||
httpretty.register_uri(httpretty.HEAD, "https://my_service.com/get_endpoint",
|
||||
|
@ -218,9 +220,9 @@ class TestRedirect(unittest.TestCase):
|
|||
])
|
||||
|
||||
response = self.client.send(request)
|
||||
self.assertEqual(response.status_code, 307, msg="Should not redirect on 307 without location header")
|
||||
self.assertEqual(response.history, [])
|
||||
self.assertFalse(response.is_redirect)
|
||||
assert response.status_code == 307, "Should not redirect on 307 without location header"
|
||||
assert response.history == []
|
||||
assert not response.is_redirect
|
||||
|
||||
@httpretty.activate
|
||||
def test_request_redirect_delete(self):
|
||||
|
@ -233,14 +235,14 @@ class TestRedirect(unittest.TestCase):
|
|||
responses=[
|
||||
httpretty.Response(body="", status=307, method='DELETE', location='/http/success/200'),
|
||||
])
|
||||
|
||||
|
||||
response = self.client.send(request)
|
||||
self.assertEqual(response.status_code, 200, msg="Should redirect on 307 with location header")
|
||||
self.assertEqual(response.request.method, 'DELETE')
|
||||
|
||||
self.assertEqual(response.history[0].status_code, 307)
|
||||
self.assertTrue(response.history[0].is_redirect)
|
||||
|
||||
response = self.client.send(request)
|
||||
assert response.status_code == 200, "Should redirect on 307 with location header"
|
||||
assert response.request.method == 'DELETE'
|
||||
|
||||
assert response.history[0].status_code == 307
|
||||
assert response.history[0].is_redirect
|
||||
|
||||
httpretty.reset()
|
||||
httpretty.register_uri(httpretty.DELETE, "https://my_service.com/get_endpoint",
|
||||
|
@ -249,9 +251,9 @@ class TestRedirect(unittest.TestCase):
|
|||
])
|
||||
|
||||
response = self.client.send(request)
|
||||
self.assertEqual(response.status_code, 307, msg="Should not redirect on 307 without location header")
|
||||
self.assertEqual(response.history, [])
|
||||
self.assertFalse(response.is_redirect)
|
||||
assert response.status_code == 307, "Should not redirect on 307 without location header"
|
||||
assert response.history == []
|
||||
assert not response.is_redirect
|
||||
|
||||
@httpretty.activate
|
||||
def test_request_redirect_put(self):
|
||||
|
@ -265,9 +267,9 @@ class TestRedirect(unittest.TestCase):
|
|||
])
|
||||
|
||||
response = self.client.send(request)
|
||||
self.assertEqual(response.status_code, 305, msg="Should not redirect on 305")
|
||||
self.assertEqual(response.history, [])
|
||||
self.assertFalse(response.is_redirect)
|
||||
assert response.status_code == 305, "Should not redirect on 305"
|
||||
assert response.history == []
|
||||
assert not response.is_redirect
|
||||
|
||||
@httpretty.activate
|
||||
def test_request_redirect_get(self):
|
||||
|
@ -301,7 +303,7 @@ class TestRedirect(unittest.TestCase):
|
|||
])
|
||||
|
||||
with self.assertRaises(ClientRequestError, msg="Should exceed maximum redirects"):
|
||||
response = self.client.send(request)
|
||||
self.client.send(request)
|
||||
|
||||
|
||||
|
||||
|
@ -325,8 +327,8 @@ class TestRuntimeRetry(unittest.TestCase):
|
|||
httpretty.Response(body="retry response", status=502),
|
||||
httpretty.Response(body='success response', status=202),
|
||||
])
|
||||
|
||||
|
||||
|
||||
|
||||
response = self.client.send(self.request)
|
||||
self.assertEqual(response.status_code, 202, msg="Should retry on 502")
|
||||
|
||||
|
@ -364,8 +366,8 @@ class TestRuntimeRetry(unittest.TestCase):
|
|||
])
|
||||
|
||||
with self.assertRaises(ClientRequestError, msg="Max retries reached"):
|
||||
response = self.client.send(self.request)
|
||||
|
||||
self.client.send(self.request)
|
||||
|
||||
@httpretty.activate
|
||||
def test_request_retry_404(self):
|
||||
httpretty.register_uri(httpretty.GET, "https://my_service.com/get_endpoint",
|
||||
|
@ -399,6 +401,6 @@ class TestRuntimeRetry(unittest.TestCase):
|
|||
response = self.client.send(self.request)
|
||||
self.assertEqual(response.status_code, 505, msg="Shouldn't retry on 505")
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -31,10 +31,6 @@ import logging
|
|||
from enum import Enum
|
||||
from datetime import datetime, timedelta, date
|
||||
import unittest
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
import mock
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
|
@ -180,10 +176,7 @@ class TestModelDeserialization(unittest.TestCase):
|
|||
"location": "westus"
|
||||
}
|
||||
|
||||
resp = mock.create_autospec(Response)
|
||||
resp.text = json.dumps(data)
|
||||
resp.headers = {"content-type": "application/json; charset=utf8"}
|
||||
model = self.d('GenericResource', resp)
|
||||
model = self.d('GenericResource', json.dumps(data), 'application/json')
|
||||
self.assertEqual(model.properties['platformFaultDomainCount'], 3)
|
||||
self.assertEqual(model.location, 'westus')
|
||||
|
||||
|
@ -1305,47 +1298,6 @@ class TestRuntimeDeserialized(unittest.TestCase):
|
|||
self.d = Deserializer()
|
||||
return super(TestRuntimeDeserialized, self).setUp()
|
||||
|
||||
def test_unpack(self):
|
||||
result = Deserializer._unpack_content("<groot/>", content_type="application/xml")
|
||||
assert result.tag == "groot"
|
||||
|
||||
# Catch some weird situation where content_type is XML, but content is JSON
|
||||
result = Deserializer._unpack_content('{"ugly": true}', content_type="application/xml")
|
||||
assert result["ugly"] is True
|
||||
|
||||
# Be sure I catch the correct exception if it's neither XML nor JSON
|
||||
with pytest.raises(ET.ParseError):
|
||||
result = Deserializer._unpack_content('gibberish', content_type="application/xml")
|
||||
with pytest.raises(ET.ParseError):
|
||||
result = Deserializer._unpack_content('{{gibberish}}', content_type="application/xml")
|
||||
|
||||
result = Deserializer._unpack_content('{"success": true}', content_type="application/json")
|
||||
assert result["success"] is True
|
||||
|
||||
# For compat, if no content-type, and direct string, just return it
|
||||
result = Deserializer._unpack_content('data')
|
||||
assert result == "data"
|
||||
|
||||
# Decore bytes
|
||||
result = Deserializer._unpack_content(b'data')
|
||||
assert result == "data"
|
||||
|
||||
response = Response()
|
||||
response.headers["content-type"] = "application/json"
|
||||
response._content = b'{"success": true}'
|
||||
response._content_consumed = True
|
||||
|
||||
result = Deserializer._unpack_content(response)
|
||||
assert result["success"] is True
|
||||
|
||||
# If no content-type, assume it's JSON
|
||||
response = Response()
|
||||
response._content = b'{"success": true}'
|
||||
response._content_consumed = True
|
||||
|
||||
result = Deserializer._unpack_content(response)
|
||||
assert result["success"] is True
|
||||
|
||||
def test_cls_method_deserialization(self):
|
||||
json_data = {
|
||||
'id': 'myid',
|
||||
|
@ -1517,80 +1469,56 @@ class TestRuntimeDeserialized(unittest.TestCase):
|
|||
"""
|
||||
Test invalid JSON
|
||||
"""
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.headers = {"content-type": "application/json; charset=utf8"}
|
||||
response_data.text = '["tata"]]'
|
||||
|
||||
with self.assertRaises(DeserializationError):
|
||||
self.d("[str]", response_data)
|
||||
self.d("[str]", '["tata"]]', 'application/json')
|
||||
|
||||
|
||||
def test_non_obj_deserialization(self):
|
||||
"""
|
||||
Test direct deserialization of simple types.
|
||||
"""
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.headers = {"content-type": "application/json; charset=utf8"}
|
||||
|
||||
response_data.text = ''
|
||||
with self.assertRaises(DeserializationError):
|
||||
self.d("[str]", response_data)
|
||||
self.d("[str]", '', 'application/json')
|
||||
|
||||
response_data.text = json.dumps('')
|
||||
with self.assertRaises(DeserializationError):
|
||||
self.d("[str]", response_data)
|
||||
self.d("[str]", json.dumps(''), 'application/json')
|
||||
|
||||
response_data.text = json.dumps({})
|
||||
with self.assertRaises(DeserializationError):
|
||||
self.d("[str]", response_data)
|
||||
self.d("[str]", json.dumps({}), 'application/json')
|
||||
|
||||
message = ["a","b","b"]
|
||||
response_data.text = json.dumps(message)
|
||||
response = self.d("[str]", response_data)
|
||||
response = self.d("[str]", json.dumps(message), 'application/json')
|
||||
self.assertEqual(response, message)
|
||||
|
||||
response_data.text = json.dumps(12345)
|
||||
with self.assertRaises(DeserializationError):
|
||||
self.d("[str]", response_data)
|
||||
self.d("[str]", json.dumps(12345), 'application/json')
|
||||
|
||||
response_data.text = json.dumps('true')
|
||||
response = self.d('bool', response_data)
|
||||
response = self.d('bool', json.dumps('true'), 'application/json')
|
||||
self.assertEqual(response, True)
|
||||
|
||||
response_data.text = json.dumps(1)
|
||||
response = self.d('bool', response_data)
|
||||
response = self.d('bool', json.dumps(1), 'application/json')
|
||||
self.assertEqual(response, True)
|
||||
|
||||
response_data.text = json.dumps("true1")
|
||||
with self.assertRaises(DeserializationError):
|
||||
self.d('bool', response_data)
|
||||
self.d('bool', json.dumps("true1"), 'application/json')
|
||||
|
||||
|
||||
def test_obj_with_no_attr(self):
|
||||
"""
|
||||
Test deserializing an object with no attributes.
|
||||
"""
|
||||
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.text = json.dumps({"a":"b"})
|
||||
response_data.headers = {"content-type": "application/json; charset=utf8"}
|
||||
|
||||
class EmptyResponse(Model):
|
||||
_attribute_map = {}
|
||||
_header_map = {}
|
||||
|
||||
|
||||
derserialized = self.d(EmptyResponse, response_data)
|
||||
self.assertIsInstance(derserialized, EmptyResponse)
|
||||
deserialized = self.d(EmptyResponse, json.dumps({"a":"b"}), 'application/json')
|
||||
self.assertIsInstance(deserialized, EmptyResponse)
|
||||
|
||||
def test_obj_with_malformed_map(self):
|
||||
"""
|
||||
Test deserializing an object with a malformed attributes_map.
|
||||
"""
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.text = json.dumps({"a":"b"})
|
||||
response_data.headers = {"content-type": "application/json; charset=utf8"}
|
||||
|
||||
class BadResponse(Model):
|
||||
_attribute_map = None
|
||||
|
||||
|
@ -1598,7 +1526,7 @@ class TestRuntimeDeserialized(unittest.TestCase):
|
|||
pass
|
||||
|
||||
with self.assertRaises(DeserializationError):
|
||||
self.d(BadResponse, response_data)
|
||||
self.d(BadResponse, json.dumps({"a":"b"}), 'application/json')
|
||||
|
||||
class BadResponse(Model):
|
||||
_attribute_map = {"attr":"val"}
|
||||
|
@ -1607,7 +1535,7 @@ class TestRuntimeDeserialized(unittest.TestCase):
|
|||
pass
|
||||
|
||||
with self.assertRaises(DeserializationError):
|
||||
self.d(BadResponse, response_data)
|
||||
self.d(BadResponse, json.dumps({"a":"b"}), 'application/json')
|
||||
|
||||
class BadResponse(Model):
|
||||
_attribute_map = {"attr":{"val":1}}
|
||||
|
@ -1616,141 +1544,89 @@ class TestRuntimeDeserialized(unittest.TestCase):
|
|||
pass
|
||||
|
||||
with self.assertRaises(DeserializationError):
|
||||
self.d(BadResponse, response_data)
|
||||
self.d(BadResponse, json.dumps({"a":"b"}), 'application/json')
|
||||
|
||||
def test_attr_none(self):
|
||||
"""
|
||||
Test serializing an object with None attributes.
|
||||
"""
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.headers = {"content-type": "application/json; charset=utf8"}
|
||||
response_data.text = 'null'
|
||||
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, 'null', 'application/json')
|
||||
self.assertIsNone(response)
|
||||
|
||||
def test_attr_int(self):
|
||||
"""
|
||||
Test deserializing an object with Int attributes.
|
||||
"""
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.status_code = 200
|
||||
response_data.headers = {
|
||||
'client-request-id':"123",
|
||||
'etag':456.3,
|
||||
"content-type": "application/json; charset=utf8"
|
||||
}
|
||||
response_data.text = ''
|
||||
|
||||
message = {'AttrB':'1234'}
|
||||
response_data.text = json.dumps(message)
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps(message), 'application/json')
|
||||
self.assertTrue(hasattr(response, 'attr_b'))
|
||||
self.assertEqual(response.attr_b, int(message['AttrB']))
|
||||
|
||||
with self.assertRaises(DeserializationError):
|
||||
response_data.text = json.dumps({'AttrB':'NotANumber'})
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'AttrB':'NotANumber'}), 'application/json')
|
||||
|
||||
def test_attr_str(self):
|
||||
"""
|
||||
Test deserializing an object with Str attributes.
|
||||
"""
|
||||
message = {'id':'InterestingValue'}
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.status_code = 200
|
||||
response_data.headers = {
|
||||
'client-request-id': 'a',
|
||||
'etag': 'b',
|
||||
"content-type": "application/json; charset=utf8"
|
||||
}
|
||||
response_data.text = json.dumps(message)
|
||||
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps(message), 'application/json')
|
||||
self.assertTrue(hasattr(response, 'attr_a'))
|
||||
self.assertEqual(response.attr_a, message['id'])
|
||||
|
||||
message = {'id':1234}
|
||||
response_data.text = json.dumps(message)
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps(message), 'application/json')
|
||||
self.assertEqual(response.attr_a, str(message['id']))
|
||||
|
||||
message = {'id':list()}
|
||||
response_data.text = json.dumps(message)
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps(message), 'application/json')
|
||||
self.assertEqual(response.attr_a, str(message['id']))
|
||||
|
||||
response_data.text = json.dumps({'id':None})
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'id':None}), 'application/json')
|
||||
self.assertEqual(response.attr_a, None)
|
||||
|
||||
def test_attr_bool(self):
|
||||
"""
|
||||
Test deserializing an object with bool attributes.
|
||||
"""
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.status_code = 200
|
||||
response_data.headers = {
|
||||
'client-request-id': 'a',
|
||||
'etag': 'b',
|
||||
"content-type": "application/json; charset=utf8"
|
||||
}
|
||||
response_data.text = json.dumps({'Key_C':True})
|
||||
|
||||
response = self.d(self.TestObj, response_data)
|
||||
|
||||
response = self.d(self.TestObj, json.dumps({'Key_C':True}), 'application/json')
|
||||
self.assertTrue(hasattr(response, 'attr_c'))
|
||||
self.assertEqual(response.attr_c, True)
|
||||
|
||||
response_data.text = json.dumps({'Key_C':[]})
|
||||
with self.assertRaises(DeserializationError):
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'Key_C':[]}), 'application/json')
|
||||
|
||||
response_data.text = json.dumps({'Key_C':0})
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'Key_C':0}), 'application/json')
|
||||
self.assertEqual(response.attr_c, False)
|
||||
|
||||
response_data.text = json.dumps({'Key_C':"value"})
|
||||
with self.assertRaises(DeserializationError):
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'Key_C':"value"}), 'application/json')
|
||||
|
||||
def test_attr_list_simple(self):
|
||||
"""
|
||||
Test deserializing an object with simple-typed list attributes
|
||||
"""
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.status_code = 200
|
||||
response_data.headers = {
|
||||
'client-request-id': 'a',
|
||||
'etag': 'b',
|
||||
"content-type": "application/json; charset=utf8"
|
||||
}
|
||||
response_data.text = json.dumps({'AttrD': []})
|
||||
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'AttrD': []}), 'application/json')
|
||||
deserialized_list = [d for d in response.attr_d]
|
||||
self.assertEqual(deserialized_list, [])
|
||||
|
||||
message = {'AttrD': [1,2,3]}
|
||||
response_data.text = json.dumps(message)
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps(message), 'application/json')
|
||||
deserialized_list = [d for d in response.attr_d]
|
||||
self.assertEqual(deserialized_list, message['AttrD'])
|
||||
|
||||
message = {'AttrD': ["1","2","3"]}
|
||||
response_data.text = json.dumps(message)
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps(message), 'application/json')
|
||||
deserialized_list = [d for d in response.attr_d]
|
||||
self.assertEqual(deserialized_list, [int(i) for i in message['AttrD']])
|
||||
|
||||
response_data.text = json.dumps({'AttrD': ["test","test2","test3"]})
|
||||
with self.assertRaises(DeserializationError):
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'AttrD': ["test","test2","test3"]}), 'application/json')
|
||||
deserialized_list = [d for d in response.attr_d]
|
||||
|
||||
response_data.text = json.dumps({'AttrD': "NotAList"})
|
||||
with self.assertRaises(DeserializationError):
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'AttrD': "NotAList"}), 'application/json')
|
||||
deserialized_list = [d for d in response.attr_d]
|
||||
|
||||
self.assertListEqual(sorted(self.d("[str]", ["a", "b", "c"])), ["a", "b", "c"])
|
||||
|
@ -1760,49 +1636,30 @@ class TestRuntimeDeserialized(unittest.TestCase):
|
|||
"""
|
||||
Test deserializing a list of lists
|
||||
"""
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.status_code = 200
|
||||
response_data.headers = {
|
||||
'client-request-id': 'a',
|
||||
'etag': 'b',
|
||||
"content-type": "application/json; charset=utf8"
|
||||
}
|
||||
response_data.text = json.dumps({'AttrF':[]})
|
||||
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'AttrF':[]}), 'application/json')
|
||||
self.assertTrue(hasattr(response, 'attr_f'))
|
||||
self.assertEqual(response.attr_f, [])
|
||||
|
||||
response_data.text = json.dumps({'AttrF':None})
|
||||
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'AttrF':None}), 'application/json')
|
||||
self.assertTrue(hasattr(response, 'attr_f'))
|
||||
self.assertEqual(response.attr_f, None)
|
||||
|
||||
response_data.text = json.dumps({})
|
||||
|
||||
response = self.d(self.TestObj, response_data)
|
||||
|
||||
response = self.d(self.TestObj, json.dumps({}), 'application/json')
|
||||
self.assertTrue(hasattr(response, 'attr_f'))
|
||||
self.assertEqual(response.attr_f, None)
|
||||
|
||||
message = {'AttrF':[[]]}
|
||||
response_data.text = json.dumps(message)
|
||||
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps(message), 'application/json')
|
||||
self.assertTrue(hasattr(response, 'attr_f'))
|
||||
self.assertEqual(response.attr_f, message['AttrF'])
|
||||
|
||||
message = {'AttrF':[[1,2,3], ['a','b','c']]}
|
||||
response_data.text = json.dumps(message)
|
||||
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps(message), 'application/json')
|
||||
self.assertTrue(hasattr(response, 'attr_f'))
|
||||
self.assertEqual(response.attr_f, [[str(i) for i in k] for k in message['AttrF']])
|
||||
|
||||
with self.assertRaises(DeserializationError):
|
||||
response_data.text = json.dumps({'AttrF':[1,2,3]})
|
||||
response = self.d(self.TestObj, response_data)
|
||||
response = self.d(self.TestObj, json.dumps({'AttrF':[1,2,3]}), 'application/json')
|
||||
|
||||
def test_attr_list_complex(self):
|
||||
"""
|
||||
|
@ -1816,17 +1673,8 @@ class TestRuntimeDeserialized(unittest.TestCase):
|
|||
_attribute_map = {'attr_a': {'key':'id', 'type':'[ListObj]'}}
|
||||
|
||||
|
||||
response_data = mock.create_autospec(Response)
|
||||
response_data.status_code = 200
|
||||
response_data.headers = {
|
||||
'client-request-id': 'a',
|
||||
'etag': 'b',
|
||||
"content-type": "application/json; charset=utf8"
|
||||
}
|
||||
response_data.text = json.dumps({"id":[{"ABC": "123"}]})
|
||||
|
||||
d = Deserializer({'ListObj':ListObj})
|
||||
response = d(CmplxTestObj, response_data)
|
||||
response = d(CmplxTestObj, json.dumps({"id":[{"ABC": "123"}]}), 'application/json')
|
||||
deserialized_list = list(response.attr_a)
|
||||
|
||||
self.assertIsInstance(deserialized_list[0], ListObj)
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
#--------------------------------------------------------------------------
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
#
|
||||
# The MIT License (MIT)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the ""Software""), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
#
|
||||
#--------------------------------------------------------------------------
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
import mock
|
||||
|
||||
import requests
|
||||
|
||||
import pytest
|
||||
|
||||
from msrest.exceptions import DeserializationError
|
||||
from msrest.universal_http import (
|
||||
ClientRequest,
|
||||
ClientResponse,
|
||||
HTTPClientResponse,
|
||||
)
|
||||
from msrest.universal_http.requests import RequestsClientResponse
|
||||
|
||||
from msrest.pipeline import (
|
||||
Response,
|
||||
Request
|
||||
)
|
||||
from msrest.pipeline.universal import (
|
||||
HTTPLogger,
|
||||
RawDeserializer,
|
||||
UserAgentPolicy
|
||||
)
|
||||
|
||||
def test_user_agent():
|
||||
|
||||
with mock.patch.dict('os.environ', {'AZURE_HTTP_USER_AGENT': "mytools"}):
|
||||
policy = UserAgentPolicy()
|
||||
assert policy.user_agent.endswith("mytools")
|
||||
|
||||
request = ClientRequest('GET', 'http://127.0.0.1/')
|
||||
policy.on_request(Request(request))
|
||||
assert request.headers["user-agent"].endswith("mytools")
|
||||
|
||||
@mock.patch('msrest.http_logger._LOGGER')
|
||||
def test_no_log(mock_http_logger):
|
||||
universal_request = ClientRequest('GET', 'http://127.0.0.1/')
|
||||
request = Request(universal_request)
|
||||
http_logger = HTTPLogger()
|
||||
response = Response(request, ClientResponse(universal_request, None))
|
||||
|
||||
# By default, no log handler for HTTP
|
||||
http_logger.on_request(request)
|
||||
mock_http_logger.debug.assert_not_called()
|
||||
http_logger.on_response(request, response)
|
||||
mock_http_logger.debug.assert_not_called()
|
||||
mock_http_logger.reset_mock()
|
||||
|
||||
# I can enable it per request
|
||||
http_logger.on_request(request, **{"enable_http_logger": True})
|
||||
assert mock_http_logger.debug.call_count >= 1
|
||||
http_logger.on_response(request, response, **{"enable_http_logger": True})
|
||||
assert mock_http_logger.debug.call_count >= 1
|
||||
mock_http_logger.reset_mock()
|
||||
|
||||
# I can enable it per request (bool value should be honored)
|
||||
http_logger.on_request(request, **{"enable_http_logger": False})
|
||||
mock_http_logger.debug.assert_not_called()
|
||||
http_logger.on_response(request, response, **{"enable_http_logger": False})
|
||||
mock_http_logger.debug.assert_not_called()
|
||||
mock_http_logger.reset_mock()
|
||||
|
||||
# I can enable it globally
|
||||
http_logger.enable_http_logger = True
|
||||
http_logger.on_request(request)
|
||||
assert mock_http_logger.debug.call_count >= 1
|
||||
http_logger.on_response(request, response)
|
||||
assert mock_http_logger.debug.call_count >= 1
|
||||
mock_http_logger.reset_mock()
|
||||
|
||||
# I can enable it globally and override it locally
|
||||
http_logger.enable_http_logger = True
|
||||
http_logger.on_request(request, **{"enable_http_logger": False})
|
||||
mock_http_logger.debug.assert_not_called()
|
||||
http_logger.on_response(request, response, **{"enable_http_logger": False})
|
||||
mock_http_logger.debug.assert_not_called()
|
||||
mock_http_logger.reset_mock()
|
||||
|
||||
|
||||
def test_raw_deserializer():
|
||||
raw_deserializer = RawDeserializer()
|
||||
|
||||
def build_response(body, content_type=None):
|
||||
class MockResponse(HTTPClientResponse):
|
||||
def __init__(self, body, content_type):
|
||||
super(MockResponse, self).__init__(None, None)
|
||||
self._body = body
|
||||
if content_type:
|
||||
self.headers['content-type'] = content_type
|
||||
|
||||
def body(self):
|
||||
return self._body
|
||||
return Response(None, MockResponse(body, content_type))
|
||||
|
||||
response = build_response(b"<groot/>", content_type="application/xml")
|
||||
raw_deserializer.on_response(None, response, stream=False)
|
||||
result = response.context["deserialized_data"]
|
||||
assert result.tag == "groot"
|
||||
|
||||
# Catch some weird situation where content_type is XML, but content is JSON
|
||||
response = build_response(b'{"ugly": true}', content_type="application/xml")
|
||||
raw_deserializer.on_response(None, response, stream=False)
|
||||
result = response.context["deserialized_data"]
|
||||
assert result["ugly"] is True
|
||||
|
||||
# Be sure I catch the correct exception if it's neither XML nor JSON
|
||||
with pytest.raises(DeserializationError):
|
||||
response = build_response(b'gibberish', content_type="application/xml")
|
||||
raw_deserializer.on_response(None, response, stream=False)
|
||||
with pytest.raises(DeserializationError):
|
||||
response = build_response(b'{{gibberish}}', content_type="application/xml")
|
||||
raw_deserializer.on_response(None, response, stream=False)
|
||||
|
||||
# Simple JSON
|
||||
response = build_response(b'{"success": true}', content_type="application/json")
|
||||
raw_deserializer.on_response(None, response, stream=False)
|
||||
result = response.context["deserialized_data"]
|
||||
assert result["success"] is True
|
||||
|
||||
# For compat, if no content-type, decode JSON
|
||||
response = build_response(b'"data"')
|
||||
raw_deserializer.on_response(None, response, stream=False)
|
||||
result = response.context["deserialized_data"]
|
||||
assert result == "data"
|
||||
|
||||
# Try with a mock of requests
|
||||
|
||||
req_response = requests.Response()
|
||||
req_response.headers["content-type"] = "application/json"
|
||||
req_response._content = b'{"success": true}'
|
||||
req_response._content_consumed = True
|
||||
response = Response(None, RequestsClientResponse(None, req_response))
|
||||
|
||||
raw_deserializer.on_response(None, response, stream=False)
|
||||
result = response.context["deserialized_data"]
|
||||
assert result["success"] is True
|
|
@ -26,8 +26,6 @@
|
|||
import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import requests
|
||||
|
||||
import pytest
|
||||
|
||||
from msrest.serialization import Serializer, Deserializer, Model, xml_key_extractor
|
||||
|
@ -140,27 +138,6 @@ class TestXmlDeserialization:
|
|||
assert child.tag == "Age"
|
||||
assert child.text == "37"
|
||||
|
||||
def test_object_from_requests(self):
|
||||
basic_xml = b"""<?xml version="1.0"?>
|
||||
<Data country="france">
|
||||
<Age>37</Age>
|
||||
</Data>"""
|
||||
|
||||
response = requests.Response()
|
||||
response.headers["content-type"] = "application/xml; charset=utf-8"
|
||||
response._content = basic_xml
|
||||
response._content_consumed = True
|
||||
|
||||
s = Deserializer()
|
||||
result = s('object', response)
|
||||
|
||||
# Should be a XML tree
|
||||
assert result.tag == "Data"
|
||||
assert result.get("country") == "france"
|
||||
for child in result:
|
||||
assert child.tag == "Age"
|
||||
assert child.text == "37"
|
||||
|
||||
def test_basic_empty(self):
|
||||
"""Test an basic XML with an empty node."""
|
||||
basic_xml = """<?xml version="1.0"?>
|
||||
|
|
2
tox.ini
2
tox.ini
|
@ -13,4 +13,4 @@ commands=
|
|||
pytest --cov=msrest tests/
|
||||
autorest: pytest --cov=msrest --cov-append autorest.python/test/vanilla/
|
||||
coverage report --fail-under=40
|
||||
coverage xml
|
||||
coverage xml --ignore-errors # At this point, don't fail for "async" keyword in 2.7/3.4
|
||||
|
|
Загрузка…
Ссылка в новой задаче