Bringing in new changes from Autorest.python

This commit is contained in:
Varad Meru [gmail] 2020-08-26 23:29:30 -07:00
Родитель 2b239b3a3a
Коммит 4b8e0a34eb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: D65E2959EB74910D
23 изменённых файлов: 469 добавлений и 315 удалений

12
.gitignore поставляемый
Просмотреть файл

@ -8,7 +8,6 @@ __pycache__/
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
@ -27,6 +26,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
package-lock.json
# PyInstaller
# Usually these files are written by a python script from a template
@ -111,6 +111,7 @@ ENV/
env.bak/
venv.bak/
py3env/
env/
# Spyder project settings
.spyderproject
@ -155,4 +156,11 @@ autorest-azurefunction*gz
autorest-azure-functions*gz
# NPMrc files
.npmrc
.npmrc
# autorest default folder
generated
**/code-model-v4-no-tags.yaml
# Generated test folders
test/services/*/_generated

Просмотреть файл

@ -5,7 +5,8 @@
# --------------------------------------------------------------------------
from typing import Any, Dict
from .base_model import BaseModel
from .code_model import CodeModel, CredentialSchema
from .code_model import CodeModel
from .credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .object_schema import ObjectSchema
from .dictionary_schema import DictionarySchema
from .list_schema import ListSchema
@ -13,6 +14,7 @@ from .primitive_schemas import get_primitive_schema, AnySchema, PrimitiveSchema
from .enum_schema import EnumSchema
from .base_schema import BaseSchema
from .constant_schema import ConstantSchema
from .credential_schema import CredentialSchema
from .imports import FileImport, ImportType, TypingSection
from .lro_operation import LROOperation
from .paging_operation import PagingOperation
@ -25,11 +27,12 @@ from .parameter_list import ParameterList
__all__ = [
"AzureKeyCredentialSchema",
"BaseModel",
"BaseSchema",
"CodeModel",
"CredentialSchema",
"ConstantSchema",
"CredentialSchema",
"ObjectSchema",
"DictionarySchema",
"ListSchema",
@ -45,7 +48,8 @@ __all__ = [
"ParameterList",
"OperationGroup",
"Property",
"SchemaResponse"
"SchemaResponse",
"TokenCredentialSchema",
]
def _generate_as_object_schema(yaml_data: Dict[str, Any]) -> bool:

Просмотреть файл

@ -46,6 +46,8 @@ class BaseSchema(BaseModel, ABC):
attrs_list.append(f"'prefix': '{self.xml_metadata['prefix']}'")
if self.xml_metadata.get("namespace", False):
attrs_list.append(f"'ns': '{self.xml_metadata['namespace']}'")
if self.xml_metadata.get("text"):
attrs_list.append("'text': True")
return ", ".join(attrs_list)
def imports(self) -> FileImport: # pylint: disable=no-self-use

Просмотреть файл

@ -28,8 +28,6 @@ class Client:
file_import.add_from_import("msrest", "Deserializer", ImportType.AZURECORE)
file_import.add_from_import("typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL)
# if code_model.options["credential"]:
# file_import.add_from_import("azure.core.credentials", "TokenCredential", ImportType.AZURECORE)
any_optional_gp = any(not gp.required for gp in code_model.global_parameters)
if any_optional_gp or code_model.base_url:
@ -43,4 +41,10 @@ class Client:
file_import.add_from_import(
"azure.core", Client.pipeline_class(code_model, async_mode), ImportType.AZURECORE
)
if not code_model.sorted_schemas:
# in this case, we have client_models = {} in the service client, which needs a type annotation
# this import will always be commented, so will always add it to the typing section
file_import.add_from_import("typing", "Dict", ImportType.STDLIB, TypingSection.TYPING)
return file_import

Просмотреть файл

@ -5,9 +5,10 @@
# --------------------------------------------------------------------------
from itertools import chain
import logging
from typing import cast, List, Dict, Optional, Any, Set
from typing import cast, List, Dict, Optional, Any, Set, Union
from .base_schema import BaseSchema
from .credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .enum_schema import EnumSchema
from .object_schema import ObjectSchema
from .operation_group import OperationGroup
@ -17,7 +18,6 @@ from .paging_operation import PagingOperation
from .parameter import Parameter, ParameterLocation
from .client import Client
from .parameter_list import ParameterList
from .imports import FileImport, ImportType, TypingSection
from .schema_response import SchemaResponse
from .property import Property
from .primitive_schemas import IOSchema
@ -26,50 +26,6 @@ from .primitive_schemas import IOSchema
_LOGGER = logging.getLogger(__name__)
class CredentialSchema(BaseSchema):
def __init__(self, async_mode) -> None: # pylint: disable=super-init-not-called
self.async_mode = async_mode
self.async_type = "~azure.core.credentials_async.AsyncTokenCredential"
self.sync_type = "~azure.core.credentials.TokenCredential"
self.default_value = None
@property
def serialization_type(self) -> str:
if self.async_mode:
return self.async_type
return self.sync_type
@property
def docstring_type(self) -> str:
return self.serialization_type
@property
def type_annotation(self) -> str:
if self.async_mode:
return '"AsyncTokenCredential"'
return '"TokenCredential"'
@property
def docstring_text(self) -> str:
return "credential"
def imports(self) -> FileImport:
file_import = FileImport()
if self.async_mode:
file_import.add_from_import(
"azure.core.credentials_async", "AsyncTokenCredential",
ImportType.AZURECORE,
typing_section=TypingSection.TYPING
)
else:
file_import.add_from_import(
"azure.core.credentials", "TokenCredential",
ImportType.AZURECORE,
typing_section=TypingSection.TYPING
)
return file_import
class CodeModel: # pylint: disable=too-many-instance-attributes
"""Holds all of the information we have parsed out of the yaml file. The CodeModel is what gets
serialized by the serializers.
@ -168,7 +124,11 @@ class CodeModel: # pylint: disable=too-many-instance-attributes
:return: None
:rtype: None
"""
credential_schema = CredentialSchema(async_mode=False)
credential_schema: Union[AzureKeyCredentialSchema, TokenCredentialSchema]
if self.options["credential_default_policy_type"] == "BearerTokenCredentialPolicy":
credential_schema = TokenCredentialSchema(async_mode=False)
else:
credential_schema = AzureKeyCredentialSchema()
credential_parameter = Parameter(
yaml_data={},
schema=credential_schema,
@ -191,6 +151,7 @@ class CodeModel: # pylint: disable=too-many-instance-attributes
description="",
url=operation.url,
method=operation.method,
multipart=operation.multipart,
api_versions=operation.api_versions,
parameters=operation.parameters.parameters,
requests=operation.requests,
@ -246,17 +207,18 @@ class CodeModel: # pylint: disable=too-many-instance-attributes
def _add_properties_from_inheritance_helper(schema, properties) -> List[Property]:
if not schema.base_models:
return properties
for base_model in schema.base_models:
parent = cast(ObjectSchema, base_model)
schema_property_names = [p.name for p in properties]
chosen_parent_properties = [
p for p in parent.properties
if p.name not in schema_property_names
]
properties = (
CodeModel._add_properties_from_inheritance_helper(parent, chosen_parent_properties) +
properties
)
if schema.base_models:
for base_model in schema.base_models:
parent = cast(ObjectSchema, base_model)
schema_property_names = [p.name for p in properties]
chosen_parent_properties = [
p for p in parent.properties
if p.name not in schema_property_names
]
properties = (
CodeModel._add_properties_from_inheritance_helper(parent, chosen_parent_properties) +
properties
)
return properties
@ -367,3 +329,11 @@ class CodeModel: # pylint: disable=too-many-instance-attributes
raise ValueError("You are missing a parameter that has multiple media types")
chosen_parameter.multiple_media_types_type_annot = f"Union[{type_annot}]"
chosen_parameter.multiple_media_types_docstring_type = docstring_type
@property
def has_lro_operations(self) -> bool:
return any([
isinstance(operation, LROOperation)
for operation_group in self.operation_groups
for operation in operation_group.operations
])

Просмотреть файл

@ -30,12 +30,11 @@ class ConstantSchema(BaseSchema):
self.schema = schema
def get_declaration(self, value: Any):
raise TypeError("Should not call get_declaration on a ConstantSchema. Use the constant_value property instead")
@property
def constant_value(self) -> str:
"""This string is used directly in template, as-is
"""
if value != self.value:
_LOGGER.warning(
"Passed in value of %s differs from constant value of %s. Choosing constant value",
str(value), str(self.value)
)
if self.value is None:
return "None"
return self.schema.get_declaration(self.value)

Просмотреть файл

@ -0,0 +1,84 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from .base_schema import BaseSchema
from .imports import FileImport, ImportType, TypingSection
class CredentialSchema(BaseSchema):
def __init__(self) -> None: # pylint: disable=super-init-not-called
self.default_value = None
@property
def docstring_type(self) -> str:
return self.serialization_type
@property
def docstring_text(self) -> str:
return "credential"
@property
def serialization_type(self) -> str:
# this property is added, because otherwise pylint says that
# abstract serialization_type in BaseSchema is not overridden
pass
class AzureKeyCredentialSchema(CredentialSchema):
@property
def serialization_type(self) -> str:
return "~azure.core.credentials.AzureKeyCredential"
@property
def type_annotation(self) -> str:
return "AzureKeyCredential"
def imports(self) -> FileImport:
file_import = FileImport()
file_import.add_from_import(
"azure.core.credentials",
"AzureKeyCredential",
ImportType.AZURECORE,
typing_section=TypingSection.CONDITIONAL
)
return file_import
class TokenCredentialSchema(CredentialSchema):
def __init__(self, async_mode) -> None:
super(TokenCredentialSchema, self).__init__()
self.async_mode = async_mode
self.async_type = "~azure.core.credentials_async.AsyncTokenCredential"
self.sync_type = "~azure.core.credentials.TokenCredential"
@property
def serialization_type(self) -> str:
if self.async_mode:
return self.async_type
return self.sync_type
@property
def type_annotation(self) -> str:
if self.async_mode:
return '"AsyncTokenCredential"'
return '"TokenCredential"'
def imports(self) -> FileImport:
file_import = FileImport()
if self.async_mode:
file_import.add_from_import(
"azure.core.credentials_async", "AsyncTokenCredential",
ImportType.AZURECORE,
typing_section=TypingSection.TYPING
)
else:
file_import.add_from_import(
"azure.core.credentials", "TokenCredential",
ImportType.AZURECORE,
typing_section=TypingSection.TYPING
)
return file_import

Просмотреть файл

@ -13,7 +13,6 @@ class ImportType(str, Enum):
AZURECORE = "azurecore"
LOCAL = "local"
class TypingSection(str, Enum):
REGULAR = "regular" # this import is always a typing import
CONDITIONAL = "conditional" # is a typing import when we're dealing with files that py2 will use, else regular

Просмотреть файл

@ -24,6 +24,7 @@ class LROOperation(Operation):
description: str,
url: str,
method: str,
multipart: bool,
api_versions: Set[str],
requests: List[SchemaRequest],
summary: Optional[str] = None,
@ -32,7 +33,7 @@ class LROOperation(Operation):
responses: Optional[List[SchemaResponse]] = None,
exceptions: Optional[List[SchemaResponse]] = None,
want_description_docstring: bool = True,
want_tracing: bool = True,
want_tracing: bool = True
) -> None:
super(LROOperation, self).__init__(
yaml_data,
@ -40,6 +41,7 @@ class LROOperation(Operation):
description,
url,
method,
multipart,
api_versions,
requests,
summary,
@ -65,10 +67,11 @@ class LROOperation(Operation):
rt for rt in response_types if 200 in rt.status_codes
]
if not response_types_with_200_status_code:
raise ValueError("Your swagger is invalid because you have "
"multiple response schemas for LRO method "
f"{self.python_name} and none of them have a "
"200 status code.")
raise ValueError(
"Your swagger is invalid because you have multiple response"
" schemas for LRO" + f" method {self.python_name} and "
"none of them have a 200 status code."
)
response_type = response_types_with_200_status_code[0]
response_type_schema_name = cast(BaseSchema, response_type.schema).serialization_type

Просмотреть файл

@ -0,0 +1,59 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import Any, Dict, List, Set, Optional
from .lro_operation import LROOperation
from .paging_operation import PagingOperation
from .imports import FileImport
from .schema_request import SchemaRequest
from .parameter import Parameter
from .schema_response import SchemaResponse
class LROPagingOperation(PagingOperation, LROOperation):
def __init__(
self,
yaml_data: Dict[str, Any],
name: str,
description: str,
url: str,
method: str,
multipart: bool,
api_versions: Set[str],
requests: List[SchemaRequest],
summary: Optional[str] = None,
parameters: Optional[List[Parameter]] = None,
multiple_media_type_parameters: Optional[List[Parameter]] = None,
responses: Optional[List[SchemaResponse]] = None,
exceptions: Optional[List[SchemaResponse]] = None,
want_description_docstring: bool = True,
want_tracing: bool = True
) -> None:
super(LROPagingOperation, self).__init__(
yaml_data,
name,
description,
url,
method,
multipart,
api_versions,
requests,
summary,
parameters,
multiple_media_type_parameters,
responses,
exceptions,
want_description_docstring,
want_tracing,
override_success_response_to_200=True
)
def imports(self, code_model, async_mode: bool) -> FileImport:
lro_imports = LROOperation.imports(self, code_model, async_mode)
paging_imports = PagingOperation.imports(self, code_model, async_mode)
file_import = lro_imports
file_import.merge(paging_imports)
return file_import

Просмотреть файл

@ -125,25 +125,26 @@ class ObjectSchema(BaseSchema): # pylint: disable=too-many-instance-attributes
# checking to see if this is a polymorphic class
subtype_map = None
if yaml_data.get("discriminator"):
subtype_map = {
children_yaml["discriminatorValue"]: children_yaml["language"][
"python"
]["name"]
for children_yaml in yaml_data["discriminator"][
"immediate"
].values()
}
subtype_map = {}
# map of discriminator value to child's name
for children_yaml in yaml_data["discriminator"]["immediate"].values():
subtype_map[children_yaml["discriminatorValue"]] = children_yaml["language"]["python"]["name"]
if yaml_data.get("properties"):
properties += [
Property.from_yaml(p, has_additional_properties=len(properties) > 0, **kwargs)
for p in yaml_data["properties"]
]
# this is to ensure that the attribute map type and property type are generated correctly
description = yaml_data["language"]["python"]["description"]
is_exception = False
exceptions_set = kwargs.pop("exceptions_set", None)
if exceptions_set and id(yaml_data) in exceptions_set:
is_exception = True
if exceptions_set:
if id(yaml_data) in exceptions_set:
is_exception = True
self.yaml_data = yaml_data
self.name = name

Просмотреть файл

@ -15,6 +15,7 @@ from .base_schema import BaseSchema
from .schema_request import SchemaRequest
from .object_schema import ObjectSchema
from .constant_schema import ConstantSchema
from .list_schema import ListSchema
_LOGGER = logging.getLogger(__name__)
@ -22,6 +23,7 @@ _LOGGER = logging.getLogger(__name__)
T = TypeVar('T')
OrderedSet = Dict[T, None]
_M4_HEADER_PARAMETERS = ["content_type", "accept"]
def _non_binary_schema_media_types(media_types: List[str]) -> OrderedSet[str]:
response_media_types: OrderedSet[str] = {}
@ -29,7 +31,7 @@ def _non_binary_schema_media_types(media_types: List[str]) -> OrderedSet[str]:
json_media_types = [media_type for media_type in media_types if "json" in media_type]
xml_media_types = [media_type for media_type in media_types if "xml" in media_type]
if sorted(json_media_types + xml_media_types) != sorted(media_types):
if not sorted(json_media_types + xml_media_types) == sorted(media_types):
raise ValueError("The non-binary responses with schemas of {self.name} have incorrect json or xml mime types")
if json_media_types:
if "application/json" in json_media_types:
@ -43,53 +45,24 @@ def _non_binary_schema_media_types(media_types: List[str]) -> OrderedSet[str]:
response_media_types[xml_media_types[0]] = None
return response_media_types
def _remove_multiple_content_type_parameters(parameters: List[Parameter]) -> List[Parameter]:
content_type_params = [p for p in parameters if p.serialized_name == "content_type"]
remaining_params = [p for p in parameters if p.serialized_name != "content_type"]
json_content_type_param = [p for p in content_type_params if p.yaml_data["schema"]["type"] == "constant"]
if json_content_type_param:
remaining_params.append(json_content_type_param[0])
else:
remaining_params.append(content_type_params[0])
def _remove_multiple_m4_header_parameters(parameters: List[Parameter]) -> List[Parameter]:
m4_header_params_in_schema = {
k: [p for p in parameters if p.serialized_name == k]
for k in _M4_HEADER_PARAMETERS
}
remaining_params = [p for p in parameters if p.serialized_name not in _M4_HEADER_PARAMETERS]
json_m4_header_params = {
k: [p for p in m4_header_params_in_schema[k] if p.yaml_data["schema"]["type"] == "constant"]
for k in m4_header_params_in_schema
}
for k, v in json_m4_header_params.items():
if v:
remaining_params.append(v[0])
else:
remaining_params.append(m4_header_params_in_schema[k][0])
return remaining_params
def _accept_content_type_helper(responses: List[SchemaResponse]) -> OrderedSet[str]:
media_types = {
media_type: None for response in responses for media_type in response.media_types
}
if not media_types:
return media_types
if len(media_types.keys()) == 1:
# if there's just one media type, we return it
return media_types
# if not, we want to return them as binary_media_types + non_binary_media types
binary_media_types = {
media_type: None
for media_type in list(media_types.keys())
if "json" not in media_type and "xml" not in media_type
}
non_binary_schema_media_types = {
media_type: None for media_type in list(media_types.keys())
if "json" in media_type or "xml" in media_type
}
if all(response.binary for response in responses):
response_media_types = binary_media_types
elif all(response.schema for response in responses):
response_media_types = _non_binary_schema_media_types(
list(non_binary_schema_media_types.keys())
)
else:
non_binary_schema_media_types = _non_binary_schema_media_types(
list(non_binary_schema_media_types.keys())
)
response_media_types = binary_media_types
response_media_types.update(non_binary_schema_media_types)
return response_media_types
class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many-instance-attributes
"""Represent an operation.
"""
@ -101,6 +74,7 @@ class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many
description: str,
url: str,
method: str,
multipart: bool,
api_versions: Set[str],
requests: List[SchemaRequest],
summary: Optional[str] = None,
@ -109,13 +83,14 @@ class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many
responses: Optional[List[SchemaResponse]] = None,
exceptions: Optional[List[SchemaResponse]] = None,
want_description_docstring: bool = True,
want_tracing: bool = True,
want_tracing: bool = True
) -> None:
super().__init__(yaml_data)
self.name = name
self.description = description
self.url = url
self.method = method
self.multipart = multipart
self.api_versions = api_versions
self.requests = requests
self.summary = summary
@ -134,29 +109,11 @@ class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many
def request_content_type(self) -> str:
return next(iter(
[
cast(ConstantSchema, p.schema).constant_value for p in self.parameters.constant
if p.serialized_name == "content_type"
p.schema.get_declaration(cast(ConstantSchema, p.schema).value)
for p in self.parameters.constant if p.serialized_name == "content_type"
]
))
@property
def accept_content_type(self) -> str:
if not self.has_response_body:
raise TypeError(
"There is an error in the code model we're being supplied. We're getting response media types " +
f"even though no response of {self.name} has a body"
)
response_content_types = _accept_content_type_helper(self.responses)
response_content_types.update(_accept_content_type_helper(self.exceptions))
if not response_content_types.keys():
raise TypeError(
f"Operation {self.name} has tried to get its accept_content_type even though it has no media types"
)
return ", ".join(list(response_content_types.keys()))
@property
def is_stream_request(self) -> bool:
"""Is the request is a stream, like an upload."""
@ -199,7 +156,7 @@ class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many
if parameter.skip_url_encoding:
optional_parameters.append("skip_quote=True")
if parameter.style:
if parameter.style and not parameter.explode:
if parameter.style in [ParameterStyle.simple, ParameterStyle.form]:
div_char = ","
elif parameter.style in [ParameterStyle.spaceDelimited]:
@ -212,17 +169,32 @@ class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many
raise ValueError(f"Do not support {parameter.style} yet")
optional_parameters.append(f"div='{div_char}'")
serialization_constraints = parameter.schema.serialization_constraints
optional_parameters += serialization_constraints if serialization_constraints else ""
if parameter.explode:
if not isinstance(parameter.schema, ListSchema):
raise ValueError("Got a explode boolean on a non-array schema")
serialization_schema = parameter.schema.element_type
else:
serialization_schema = parameter.schema
optional_parameters_string = "" if not optional_parameters else ", " + ", ".join(optional_parameters)
serialization_constraints = serialization_schema.serialization_constraints
if serialization_constraints:
optional_parameters += serialization_constraints
origin_name = parameter.full_serialized_name
return (
f"""self._serialize.{function_name}("{origin_name.lstrip('_')}", {origin_name}, """
+ f"""'{parameter.schema.serialization_type}'{optional_parameters_string})"""
)
parameters = [
f'"{origin_name.lstrip("_")}"',
"q" if parameter.explode else origin_name,
f"'{serialization_schema.serialization_type}'",
*optional_parameters
]
parameters_line = ', '.join(parameters)
serialize_line = f'self._serialize.{function_name}({parameters_line})'
if parameter.explode:
return f"[{serialize_line} if q is not None else '' for q in {origin_name}]"
return serialize_line
@property
def serialization_context(self) -> str:
@ -323,7 +295,7 @@ class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many
for request in yaml_data["requests"]:
for yaml in request.get("parameters", []):
parameter = Parameter.from_yaml(yaml)
if yaml["language"]["python"]["name"] == "content_type":
if yaml["language"]["python"]["name"] in _M4_HEADER_PARAMETERS:
parameter.is_kwarg = True
parameters.append(parameter)
elif multiple_requests:
@ -332,7 +304,7 @@ class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many
parameters.append(parameter)
if multiple_requests:
parameters = _remove_multiple_content_type_parameters(parameters)
parameters = _remove_multiple_m4_header_parameters(parameters)
chosen_parameter = multiple_media_type_parameters[0]
# binary body parameters are required, while object
@ -346,10 +318,9 @@ class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many
parameters.append(chosen_parameter)
if multiple_media_type_parameters:
body_parameters_name_set = {
body_parameters_name_set = set(
p.serialized_name for p in multiple_media_type_parameters
}
)
if len(body_parameters_name_set) > 1:
raise ValueError(
f"The body parameter with multiple media types has different names: {body_parameters_name_set}"
@ -374,22 +345,13 @@ class Operation(BaseModel): # pylint: disable=too-many-public-methods, too-many
description=yaml_data["language"]["python"]["description"],
url=first_request["protocol"]["http"]["path"],
method=first_request["protocol"]["http"]["method"],
api_versions={
value_dict["version"] for value_dict in yaml_data["apiVersions"]
},
requests=[
SchemaRequest.from_yaml(yaml) for yaml in yaml_data["requests"]
],
multipart=first_request["protocol"]["http"].get("multipart", False),
api_versions=set(value_dict["version"] for value_dict in yaml_data["apiVersions"]),
requests=[SchemaRequest.from_yaml(yaml) for yaml in yaml_data["requests"]],
summary=yaml_data["language"]["python"].get("summary"),
parameters=parameters,
multiple_media_type_parameters=multiple_media_type_parameters,
responses=[
SchemaResponse.from_yaml(yaml)
for yaml in yaml_data.get("responses", [])
],
exceptions=[
SchemaResponse.from_yaml(yaml)
for yaml in yaml_data.get("exceptions", [])
if "schema" in yaml
],
responses=[SchemaResponse.from_yaml(yaml) for yaml in yaml_data.get("responses", [])],
# Exception with no schema means default exception, we don't store them
exceptions=[SchemaResponse.from_yaml(yaml) for yaml in yaml_data.get("exceptions", []) if "schema" in yaml],
)

Просмотреть файл

@ -10,6 +10,7 @@ from .base_model import BaseModel
from .operation import Operation
from .lro_operation import LROOperation
from .paging_operation import PagingOperation
from .lro_paging_operation import LROPagingOperation
from .imports import FileImport, ImportType
@ -36,14 +37,21 @@ class OperationGroup(BaseModel):
self.operations = operations
self.api_versions = api_versions
@staticmethod
def imports(async_mode: bool, has_schemas: bool) -> FileImport:
def imports(self, async_mode: bool, has_schemas: bool) -> FileImport:
file_import = FileImport()
file_import.add_import("logging", ImportType.STDLIB)
file_import.add_from_import("azure.functions", "HttpRequest",
ImportType.AZURECORE)
file_import.add_from_import("azure.functions", "HttpResponse",
ImportType.AZURECORE)
file_import.add_from_import("azure.core.exceptions", "ResourceNotFoundError", ImportType.AZURECORE)
file_import.add_from_import("azure.core.exceptions", "ResourceExistsError", ImportType.AZURECORE)
for operation in self.operations:
file_import.merge(operation.imports(self.code_model, async_mode))
if self.code_model.options["tracing"]:
if async_mode:
file_import.add_from_import(
"azure.core.tracing.decorator_async", "distributed_trace_async", ImportType.AZURECORE,
)
else:
file_import.add_from_import(
"azure.core.tracing.decorator", "distributed_trace", ImportType.AZURECORE,
)
if has_schemas:
if async_mode:
file_import.add_from_import("...", "models", ImportType.LOCAL)
@ -75,9 +83,13 @@ class OperationGroup(BaseModel):
operations = []
api_versions: Set[str] = set()
for operation_yaml in yaml_data["operations"]:
if operation_yaml.get("extensions", {}).get("x-ms-long-running-operation"):
lro_operation = operation_yaml.get("extensions", {}).get("x-ms-long-running-operation")
paging_operation = operation_yaml.get("extensions", {}).get("x-ms-pageable")
if lro_operation and paging_operation:
operation = LROPagingOperation.from_yaml(operation_yaml)
elif lro_operation:
operation = LROOperation.from_yaml(operation_yaml)
elif operation_yaml.get("extensions", {}).get("x-ms-pageable"):
elif paging_operation:
operation = PagingOperation.from_yaml(operation_yaml)
else:
operation = Operation.from_yaml(operation_yaml)

Просмотреть файл

@ -4,62 +4,68 @@
# license information.
# --------------------------------------------------------------------------
import logging
from typing import Any, Dict, List, Optional, Set, cast
from typing import cast, Dict, List, Any, Optional, Set, Union
from .imports import FileImport, ImportType, TypingSection
from .object_schema import ObjectSchema
from .operation import Operation
from .parameter import Parameter
from .schema_request import SchemaRequest
from .schema_response import SchemaResponse
from .schema_request import SchemaRequest
from .imports import ImportType, FileImport, TypingSection
from .object_schema import ObjectSchema
_LOGGER = logging.getLogger(__name__)
class PagingOperation(Operation):
def __init__(
self,
yaml_data: Dict[str, Any],
name: str,
description: str,
url: str,
method: str,
api_versions: Set[str],
requests: List[SchemaRequest],
summary: Optional[str] = None,
parameters: Optional[List[Parameter]] = None,
multiple_media_type_parameters: Optional[List[Parameter]] = None,
responses: Optional[List[SchemaResponse]] = None,
exceptions: Optional[List[SchemaResponse]] = None,
self,
yaml_data: Dict[str, Any],
name: str,
description: str,
url: str,
method: str,
multipart: bool,
api_versions: Set[str],
requests: List[SchemaRequest],
summary: Optional[str] = None,
parameters: Optional[List[Parameter]] = None,
multiple_media_type_parameters: Optional[List[Parameter]] = None,
responses: Optional[List[SchemaResponse]] = None,
exceptions: Optional[List[SchemaResponse]] = None,
want_description_docstring: bool = True,
want_tracing: bool = True,
*,
override_success_response_to_200: bool = False
) -> None:
super(PagingOperation, self).__init__(
yaml_data,
name,
description,
url,
method,
api_versions,
requests,
summary,
parameters,
multiple_media_type_parameters,
responses,
exceptions
yaml_data,
name,
description,
url,
method,
multipart,
api_versions,
requests,
summary,
parameters,
multiple_media_type_parameters,
responses,
exceptions,
want_description_docstring,
want_tracing
)
self._item_name: str = yaml_data["extensions"]["x-ms-pageable"].get(
"itemName")
self._next_link_name: str = yaml_data["extensions"][
"x-ms-pageable"].get("nextLinkName")
self.operation_name: str = yaml_data["extensions"]["x-ms-pageable"].get(
"operationName")
self._item_name: str = yaml_data["extensions"]["x-ms-pageable"].get("itemName")
self._next_link_name: str = yaml_data["extensions"]["x-ms-pageable"].get("nextLinkName")
self.operation_name: str = yaml_data["extensions"]["x-ms-pageable"].get("operationName")
self.next_operation: Optional[Operation] = None
self.override_success_response_to_200 = override_success_response_to_200
def _get_response(self) -> SchemaResponse:
response = self.responses[0]
if not isinstance(response.schema, ObjectSchema):
raise ValueError(
"The response of a paging operation must be of type " +
f"ObjectSchema but {response.schema} is not"
"The response of a paging operation must be of type " +
f"ObjectSchema but {response.schema} is not"
)
return response
@ -70,24 +76,26 @@ class PagingOperation(Operation):
for prop in response_schema.properties:
if prop.original_swagger_name == rest_api_name:
return prop.name
raise ValueError("While scanning x-ms-pageable, was unable to find "
f"{log_name}:{rest_api_name} in model"
f" {response_schema.name}")
raise ValueError(
"While scanning x-ms-pageable, was unable to find "
+ f"{log_name}:{rest_api_name} in model {response_schema.name}"
)
@property
def item_name(self) -> str:
if self._item_name is None:
# Default value. I still check if I find it, so I can do a nice
# message.
# Default value. I still check if I find it,
# so I can do a nice message.
item_name = "value"
try:
return self._find_python_name(item_name, "itemName")
except ValueError:
response = self._get_response()
raise ValueError("While scanning x-ms-pageable, itemName was "
"not defined and object "
f"{response.schema.name} has no array "
"called 'value'")
raise ValueError(
"While scanning x-ms-pageable, itemName was not defined and"
" object" + f" {response.schema.name} has no array called "
"'value'"
)
return self._find_python_name(self._item_name, "itemName")
@property
@ -99,32 +107,31 @@ class PagingOperation(Operation):
@property
def has_optional_return_type(self) -> bool:
"""A paging will never have an optional return type, we will always
return ItemPaged[return type]"""
"""A paging will never have an optional return type, we will always return ItemPaged[return type]"""
return False
@property
def success_status_code(self) -> List[Union[str, int]]:
"""The list of all successfull status code.
"""
if self.override_success_response_to_200:
return [200]
return super(PagingOperation, self).success_status_code
def imports(self, code_model, async_mode: bool) -> FileImport:
file_import = super(PagingOperation, self).imports(code_model,
async_mode)
file_import = super(PagingOperation, self).imports(code_model, async_mode)
if async_mode:
file_import.add_from_import("azure.core.async_paging",
"AsyncItemPaged", ImportType.AZURECORE)
file_import.add_from_import("azure.core.async_paging", "AsyncList",
ImportType.AZURECORE)
file_import.add_from_import("typing", "AsyncIterable",
ImportType.STDLIB,
TypingSection.CONDITIONAL)
file_import.add_from_import("azure.core.async_paging", "AsyncItemPaged", ImportType.AZURECORE)
file_import.add_from_import("azure.core.async_paging", "AsyncList", ImportType.AZURECORE)
file_import.add_from_import("typing", "AsyncIterable", ImportType.STDLIB, TypingSection.CONDITIONAL)
else:
file_import.add_from_import("azure.core.paging", "ItemPaged",
ImportType.AZURECORE)
file_import.add_from_import("typing", "Iterable", ImportType.STDLIB,
TypingSection.CONDITIONAL)
file_import.add_from_import("azure.core.paging", "ItemPaged", ImportType.AZURECORE)
file_import.add_from_import("typing", "Iterable", ImportType.STDLIB, TypingSection.CONDITIONAL)
if code_model.options["tracing"]:
file_import.add_from_import(
"azure.core.tracing.decorator", "distributed_trace",
ImportType.AZURECORE,
"azure.core.tracing.decorator", "distributed_trace", ImportType.AZURECORE,
)
return file_import

Просмотреть файл

@ -37,6 +37,7 @@ class ParameterStyle(Enum):
json = "json"
binary = "binary"
xml = "xml"
multipart = "multipart"
class Parameter(BaseModel): # pylint: disable=too-many-instance-attributes
@ -54,6 +55,7 @@ class Parameter(BaseModel): # pylint: disable=too-many-instance-attributes
constraints: List[Any],
target_property_name: Optional[Union[int, str]] = None, # first uses id as placeholder
style: Optional[ParameterStyle] = None,
explode: Optional[bool] = False,
*,
flattened: bool = False,
grouped_by: Optional["Parameter"] = None,
@ -72,6 +74,7 @@ class Parameter(BaseModel): # pylint: disable=too-many-instance-attributes
self.constraints = constraints
self.target_property_name = target_property_name
self.style = style
self.explode = explode
self.flattened = flattened
self.grouped_by = grouped_by
self.original_parameter = original_parameter
@ -131,7 +134,7 @@ class Parameter(BaseModel): # pylint: disable=too-many-instance-attributes
default_value_declaration = "None"
else:
if isinstance(self.schema, ConstantSchema):
default_value = self.schema.constant_value
default_value = self.schema.get_declaration(self.schema.value)
default_value_declaration = default_value
else:
default_value = self.schema.default_value
@ -194,6 +197,7 @@ class Parameter(BaseModel): # pylint: disable=too-many-instance-attributes
constraints=[], # FIXME constraints
target_property_name=id(yaml_data["targetProperty"]) if yaml_data.get("targetProperty") else None,
style=ParameterStyle(http_protocol["style"]) if "style" in http_protocol else None,
explode=http_protocol.get("explode", False),
grouped_by=yaml_data.get("groupedBy", None),
original_parameter=yaml_data.get("originalParameter", None),
flattened=yaml_data.get("flattened", False),

Просмотреть файл

@ -56,11 +56,11 @@ class ParameterList(MutableSequence):
return self.has_any_location(ParameterLocation.Body)
@property
def body(self) -> Parameter:
def body(self) -> List[Parameter]:
if not self.has_body:
raise ValueError("Can't get body parameter")
# Should we check if there is two body? Modeler role right?
return self.get_from_location(ParameterLocation.Body)[0]
return self.get_from_location(ParameterLocation.Body)
@property
def path(self) -> List[Parameter]:
@ -140,5 +140,5 @@ class ParameterList(MutableSequence):
for param in parameters if param.target_property_name
]
)
object_schema = cast(ObjectSchema, self.body.schema)
return f"{self.body.serialized_name} = models.{object_schema.name}({parameter_string})"
object_schema = cast(ObjectSchema, self.body[0].schema)
return f"{self.body[0].serialized_name} = models.{object_schema.name}({parameter_string})"

Просмотреть файл

@ -19,7 +19,7 @@ class RawString(object):
self.string = string
def __repr__(self) -> str:
return f"r'{self.string}'"
return "r'{}'".format(self.string.replace('\'', '\\\''))
class PrimitiveSchema(BaseSchema):

Просмотреть файл

@ -4,8 +4,12 @@
# license information.
# --------------------------------------------------------------------------
from jinja2 import Environment
from autorest.codegen.serializers.azure_functions.python.import_serializer import FileImportSerializer, TypingSection
from ..models import FileImport, ImportType, CodeModel, CredentialSchema
from autorest.codegen.models import TypingSection
from autorest.codegen.serializers.azure_functions.python.import_serializer \
import \
FileImportSerializer
from ..models import FileImport, ImportType, CodeModel, TokenCredentialSchema
class GeneralSerializer:
@ -24,9 +28,9 @@ class GeneralSerializer:
def _correct_credential_parameter(self):
credential_param = [
gp for gp in self.code_model.global_parameters.parameters if isinstance(gp.schema, CredentialSchema)
gp for gp in self.code_model.global_parameters.parameters if isinstance(gp.schema, TokenCredentialSchema)
][0]
credential_param.schema = CredentialSchema(async_mode=self.async_mode)
credential_param.schema = TokenCredentialSchema(async_mode=self.async_mode)
def serialize_service_client_file(self) -> str:
def _service_client_imports() -> FileImport:
@ -37,7 +41,10 @@ class GeneralSerializer:
template = self.env.get_template("service_client.py.jinja2")
if self.code_model.options['credential']:
if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
):
self._correct_credential_parameter()
return template.render(
@ -59,6 +66,8 @@ class GeneralSerializer:
file_import.add_from_import(".._version" if async_mode else "._version", "VERSION", ImportType.LOCAL)
for gp in self.code_model.global_parameters:
file_import.merge(gp.imports())
if self.code_model.options["azure_arm"]:
file_import.add_from_import("azure.mgmt.core.policies", "ARMHttpLoggingPolicy", ImportType.AZURECORE)
return file_import
package_name = self.code_model.options['package_name']
@ -66,7 +75,10 @@ class GeneralSerializer:
package_name = package_name[len("azure-"):]
sdk_moniker = package_name if package_name else self.code_model.class_name.lower()
if self.code_model.options['credential']:
if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
):
self._correct_credential_parameter()
template = self.env.get_template("config.py.jinja2")

Просмотреть файл

@ -5,25 +5,22 @@
# --------------------------------------------------------------------------
import copy
import json
from typing import List, Optional, Set, Tuple, Dict
from typing import Dict, List, Optional, Set, Tuple
from jinja2 import Environment
from ..models import (
CodeModel,
Operation,
OperationGroup,
LROOperation,
PagingOperation,
CredentialSchema,
ParameterList,
TypingSection,
ImportType
)
from ..models import (CodeModel, ImportType, LROOperation, Operation,
OperationGroup, PagingOperation, ParameterList,
TokenCredentialSchema, TypingSection)
def _correct_credential_parameter(global_parameters: ParameterList, async_mode: bool) -> None:
credential_param = [
gp for gp in global_parameters.parameters if isinstance(gp.schema, CredentialSchema)
gp for gp in global_parameters.parameters if
isinstance(gp.schema, TokenCredentialSchema)
][0]
credential_param.schema = CredentialSchema(async_mode=async_mode)
credential_param.schema = TokenCredentialSchema(async_mode=async_mode)
def _json_serialize_imports(
imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]]
@ -32,9 +29,9 @@ def _json_serialize_imports(
return None
json_serialize_imports = {}
# need to make name_import set -> list to make the dictionary json serializable
# not using an OrderedDict since we're iterating through a set and the order there varies
# going to sort the list instead
# need to make name_import () -> [] to make the dictionary json serializable
# not using an OrderedDict since we're iterating through a set and the order
# there varies going to sort the list instead
for typing_section_key, typing_section_value in imports.items():
json_import_type_dictionary = {}
@ -107,11 +104,19 @@ class MetadataSerializer:
# In this case, we need two copies of the credential global parameter
# for typing purposes.
async_global_parameters = self.code_model.global_parameters
if self.code_model.options['credential']:
# this ensures that the CredentialSchema showing up in the list of code model's global parameters
# is sync. This way we only have to make a copy for an async_credential
_correct_credential_parameter(self.code_model.global_parameters, False)
async_global_parameters = self._make_async_copy_of_global_parameters()
if (
self.code_model.options['credential'] and
self.code_model.options[
'credential_default_policy_type'] ==
"BearerTokenCredentialPolicy"
):
# this ensures that the TokenCredentialSchema showing up in the
# list of code model's global parameters is sync.
# This way we only have to make a copy for an async_credential
_correct_credential_parameter(self.code_model.global_parameters,
False)
async_global_parameters = \
self._make_async_copy_of_global_parameters()
template = self.env.get_template("metadata.json.jinja2")
return template.render(

Просмотреть файл

@ -5,13 +5,14 @@
# --------------------------------------------------------------------------
from jinja2 import Environment
from autorest.codegen.serializers.azure_functions.python.import_serializer import FileImportSerializer
from ..models import LROOperation, PagingOperation, CodeModel, OperationGroup
from .azure_functions.python.import_serializer import FileImportSerializer
from ..models import CodeModel, LROOperation, OperationGroup, PagingOperation
class OperationGroupSerializer:
def __init__(
self, code_model: CodeModel, env: Environment, operation_group: OperationGroup, async_mode: bool
self, code_model: CodeModel, env: Environment,
operation_group: OperationGroup, async_mode: bool
) -> None:
self.code_model = code_model
self.env = env
@ -25,18 +26,25 @@ class OperationGroupSerializer:
def _is_paging(operation):
return isinstance(operation, PagingOperation)
operation_group_template = self.env.get_template("operations_container.py.jinja2")
operation_group_template = self.env.get_template(
"operations_container.py.jinja2")
if self.operation_group.is_empty_operation_group:
operation_group_template = self.env.get_template("operations_container_mixin.py.jinja2")
operation_group_template = self.env.get_template(
"operations_container_mixin.py.jinja2")
return operation_group_template.render(
code_model=self.code_model,
operation_group=self.operation_group,
imports=FileImportSerializer(
self.operation_group.imports(self.async_mode, bool(self.code_model.schemas)),
is_python_3_file=self.async_mode
),
async_mode=self.async_mode,
is_lro=_is_lro,
is_paging=_is_paging,
code_model=self.code_model,
operation_group=self.operation_group,
imports=FileImportSerializer(
self.operation_group.imports(
self.async_mode,
bool(
self.code_model.schemas or
self.code_model.enums)
),
is_python_3_file=self.async_mode
),
async_mode=self.async_mode,
is_lro=_is_lro,
is_paging=_is_paging,
)

Просмотреть файл

@ -17,7 +17,7 @@ _LOGGER = logging.getLogger(__name__)
@dispatcher.add_method
def GetPluginNames():
return ["codegen", "m2r", "namer", "namercsharp", "multiapiscript"]
return ["codegen", "m2r", "namer"]
@dispatcher.add_method
@ -34,12 +34,8 @@ def Process(plugin_name: str, session_id: str) -> bool:
from ..m2r import M2R as PluginToLoad
elif plugin_name == "namer":
from ..namer import Namer as PluginToLoad # type: ignore
elif plugin_name == "namercsharp":
from ..namer.csharp import Namer as PluginToLoad # type: ignore
elif plugin_name == "codegen":
from ..codegen import CodeGenerator as PluginToLoad # type: ignore
elif plugin_name == "multiapiscript":
from ..multiapi import MultiApiScriptPlugin as PluginToLoad # type: ignore
else:
_LOGGER.fatal("Unknown plugin name %s", plugin_name)
raise RuntimeError(f"Unknown plugin name {plugin_name}")

Просмотреть файл

@ -82,7 +82,7 @@ class NameConverter:
def _convert_enum_schema(schema: Dict[str, Any]) -> None:
NameConverter._convert_language_default_pascal_case(schema)
for choice in schema["choices"]:
NameConverter._convert_language_default_python_case(choice, pad_string=PadType.Enum)
NameConverter._convert_language_default_python_case(choice, pad_string=PadType.Enum, all_upper=True)
@staticmethod
def _convert_object_schema(schema: Dict[str, Any]) -> None:
@ -97,9 +97,21 @@ class NameConverter:
for prop in schema.get("properties", []):
NameConverter._convert_language_default_python_case(schema=prop, pad_string=PadType.Property)
@staticmethod
def _is_schema_an_m4_header_parameter(schema_name: str, schema: Dict[str, Any]) -> bool:
m4_header_parameters = ["content_type", "accept"]
return (
schema_name in m4_header_parameters and
schema.get('protocol', {}).get('http', {}).get('in', {}) == 'header'
)
@staticmethod
def _convert_language_default_python_case(
schema: Dict[str, Any], *, pad_string: Optional[PadType] = None, convert_name: bool = False
schema: Dict[str, Any],
*,
pad_string: Optional[PadType] = None,
convert_name: bool = False,
all_upper: bool = False
) -> None:
if not schema.get("language") or schema["language"].get("python"):
return
@ -107,17 +119,19 @@ class NameConverter:
schema_name = schema['language']['default']['name']
schema_python_name = schema['language']['python']['name']
if not(
schema_name == 'content_type' and
schema.get('protocol') and
schema['protocol'].get('http') and
schema['protocol']['http']['in'] == "header"
if not NameConverter._is_schema_an_m4_header_parameter(
schema_name, schema
):
# only escaping name if it's not a content_type header parameter
schema_python_name = NameConverter._to_valid_python_name(
name=schema_name, pad_string=pad_string, convert_name=convert_name
)
schema['language']['python']['name'] = schema_python_name.lower()
# need to add the lower in case certain words, like LRO, are overriden to
# always return LRO. Without .lower(), for example, begin_lro would be
# begin_LRO
schema['language']['python']['name'] = (
schema_python_name.upper() if all_upper else schema_python_name.lower()
)
schema_description = schema["language"]["default"]["description"].strip()
if pad_string == PadType.Method and not schema_description and not schema["language"]["default"].get("summary"):

Просмотреть файл

@ -100,6 +100,7 @@ reserved_words = {
"self",
# these are kwargs we've reserved for our autorest generated operations
"content_type",
"accept",
"cls",
"polling",
"continuation_token", # for LRO calls