Bringing in new changes from Autorest.python
Parent: 2b239b3a3a
Commit: 4b8e0a34eb
@@ -8,7 +8,6 @@ __pycache__/
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/

@@ -27,6 +26,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
package-lock.json

# PyInstaller
# Usually these files are written by a python script from a template

@@ -111,6 +111,7 @@ ENV/
env.bak/
venv.bak/
py3env/
env/

# Spyder project settings
.spyderproject

@@ -155,4 +156,11 @@ autorest-azurefunction*gz
autorest-azure-functions*gz

# NPMrc files
.npmrc
.npmrc

# autorest default folder
generated
**/code-model-v4-no-tags.yaml

# Generated test folders
test/services/*/_generated

@@ -5,7 +5,8 @@
# --------------------------------------------------------------------------
from typing import Any, Dict
from .base_model import BaseModel
from .code_model import CodeModel, CredentialSchema
from .code_model import CodeModel
from .credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .object_schema import ObjectSchema
from .dictionary_schema import DictionarySchema
from .list_schema import ListSchema

@@ -13,6 +14,7 @@ from .primitive_schemas import get_primitive_schema, AnySchema, PrimitiveSchema
from .enum_schema import EnumSchema
from .base_schema import BaseSchema
from .constant_schema import ConstantSchema
from .credential_schema import CredentialSchema
from .imports import FileImport, ImportType, TypingSection
from .lro_operation import LROOperation
from .paging_operation import PagingOperation

@@ -25,11 +27,12 @@ from .parameter_list import ParameterList

__all__ = [
    "AzureKeyCredentialSchema",
    "BaseModel",
    "BaseSchema",
    "CodeModel",
    "CredentialSchema",
    "ConstantSchema",
    "CredentialSchema",
    "ObjectSchema",
    "DictionarySchema",
    "ListSchema",

@@ -45,7 +48,8 @@ __all__ = [
    "ParameterList",
    "OperationGroup",
    "Property",
    "SchemaResponse"
    "SchemaResponse",
    "TokenCredentialSchema",
]

def _generate_as_object_schema(yaml_data: Dict[str, Any]) -> bool:

@@ -46,6 +46,8 @@ class BaseSchema(BaseModel, ABC):
            attrs_list.append(f"'prefix': '{self.xml_metadata['prefix']}'")
        if self.xml_metadata.get("namespace", False):
            attrs_list.append(f"'ns': '{self.xml_metadata['namespace']}'")
        if self.xml_metadata.get("text"):
            attrs_list.append("'text': True")
        return ", ".join(attrs_list)

    def imports(self) -> FileImport:  # pylint: disable=no-self-use

@@ -28,8 +28,6 @@ class Client:
        file_import.add_from_import("msrest", "Deserializer", ImportType.AZURECORE)
        file_import.add_from_import("typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL)

        # if code_model.options["credential"]:
        #     file_import.add_from_import("azure.core.credentials", "TokenCredential", ImportType.AZURECORE)
        any_optional_gp = any(not gp.required for gp in code_model.global_parameters)

        if any_optional_gp or code_model.base_url:

@@ -43,4 +41,10 @@ class Client:
        file_import.add_from_import(
            "azure.core", Client.pipeline_class(code_model, async_mode), ImportType.AZURECORE
        )

        if not code_model.sorted_schemas:
            # in this case, we have client_models = {} in the service client, which needs a type annotation
            # this import will always be commented, so will always add it to the typing section
            file_import.add_from_import("typing", "Dict", ImportType.STDLIB, TypingSection.TYPING)

        return file_import

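Illustrative sketch (not part of the commit): the comment above says an empty client_models map needs a typing annotation, which is why the Dict import goes into the typing section. The exact generated line may differ; this is only the shape of it:

from typing import Any, Dict

client_models = {}  # type: Dict[str, Any]  # comment-style annotation, hence a typing-only import
print(client_models)  # {}
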
@@ -5,9 +5,10 @@
# --------------------------------------------------------------------------
from itertools import chain
import logging
from typing import cast, List, Dict, Optional, Any, Set
from typing import cast, List, Dict, Optional, Any, Set, Union

from .base_schema import BaseSchema
from .credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .enum_schema import EnumSchema
from .object_schema import ObjectSchema
from .operation_group import OperationGroup

@@ -17,7 +18,6 @@ from .paging_operation import PagingOperation
from .parameter import Parameter, ParameterLocation
from .client import Client
from .parameter_list import ParameterList
from .imports import FileImport, ImportType, TypingSection
from .schema_response import SchemaResponse
from .property import Property
from .primitive_schemas import IOSchema

@@ -26,50 +26,6 @@ from .primitive_schemas import IOSchema
_LOGGER = logging.getLogger(__name__)


class CredentialSchema(BaseSchema):
    def __init__(self, async_mode) -> None:  # pylint: disable=super-init-not-called
        self.async_mode = async_mode
        self.async_type = "~azure.core.credentials_async.AsyncTokenCredential"
        self.sync_type = "~azure.core.credentials.TokenCredential"
        self.default_value = None

    @property
    def serialization_type(self) -> str:
        if self.async_mode:
            return self.async_type
        return self.sync_type

    @property
    def docstring_type(self) -> str:
        return self.serialization_type

    @property
    def type_annotation(self) -> str:
        if self.async_mode:
            return '"AsyncTokenCredential"'
        return '"TokenCredential"'

    @property
    def docstring_text(self) -> str:
        return "credential"

    def imports(self) -> FileImport:
        file_import = FileImport()
        if self.async_mode:
            file_import.add_from_import(
                "azure.core.credentials_async", "AsyncTokenCredential",
                ImportType.AZURECORE,
                typing_section=TypingSection.TYPING
            )
        else:
            file_import.add_from_import(
                "azure.core.credentials", "TokenCredential",
                ImportType.AZURECORE,
                typing_section=TypingSection.TYPING
            )
        return file_import


class CodeModel:  # pylint: disable=too-many-instance-attributes
    """Holds all of the information we have parsed out of the yaml file. The CodeModel is what gets
    serialized by the serializers.

@@ -168,7 +124,11 @@ class CodeModel:  # pylint: disable=too-many-instance-attributes
        :return: None
        :rtype: None
        """
        credential_schema = CredentialSchema(async_mode=False)
        credential_schema: Union[AzureKeyCredentialSchema, TokenCredentialSchema]
        if self.options["credential_default_policy_type"] == "BearerTokenCredentialPolicy":
            credential_schema = TokenCredentialSchema(async_mode=False)
        else:
            credential_schema = AzureKeyCredentialSchema()
        credential_parameter = Parameter(
            yaml_data={},
            schema=credential_schema,

@@ -191,6 +151,7 @@ class CodeModel:  # pylint: disable=too-many-instance-attributes
                description="",
                url=operation.url,
                method=operation.method,
                multipart=operation.multipart,
                api_versions=operation.api_versions,
                parameters=operation.parameters.parameters,
                requests=operation.requests,

@@ -246,17 +207,18 @@ class CodeModel:  # pylint: disable=too-many-instance-attributes
    def _add_properties_from_inheritance_helper(schema, properties) -> List[Property]:
        if not schema.base_models:
            return properties
        for base_model in schema.base_models:
            parent = cast(ObjectSchema, base_model)
            schema_property_names = [p.name for p in properties]
            chosen_parent_properties = [
                p for p in parent.properties
                if p.name not in schema_property_names
            ]
            properties = (
                CodeModel._add_properties_from_inheritance_helper(parent, chosen_parent_properties) +
                properties
            )
        if schema.base_models:
            for base_model in schema.base_models:
                parent = cast(ObjectSchema, base_model)
                schema_property_names = [p.name for p in properties]
                chosen_parent_properties = [
                    p for p in parent.properties
                    if p.name not in schema_property_names
                ]
                properties = (
                    CodeModel._add_properties_from_inheritance_helper(parent, chosen_parent_properties) +
                    properties
                )

        return properties

@@ -367,3 +329,11 @@ class CodeModel:  # pylint: disable=too-many-instance-attributes
            raise ValueError("You are missing a parameter that has multiple media types")
        chosen_parameter.multiple_media_types_type_annot = f"Union[{type_annot}]"
        chosen_parameter.multiple_media_types_docstring_type = docstring_type

    @property
    def has_lro_operations(self) -> bool:
        return any([
            isinstance(operation, LROOperation)
            for operation_group in self.operation_groups
            for operation in operation_group.operations
        ])

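Illustrative sketch (not part of the commit): how _add_properties_from_inheritance_helper merges parent properties ahead of the child's own, skipping names the child already defines. Schemas are reduced to plain dicts and properties to (name, origin) tuples; the real code walks Property/ObjectSchema models.

def merge(schema, properties):
    # mirror of the recursive helper above, on plain data
    for parent in schema["base_models"]:
        names = {name for name, _ in properties}
        chosen = [p for p in parent["properties"] if p[0] not in names]
        properties = merge(parent, chosen) + properties
    return properties

grandparent = {"base_models": [], "properties": [("etag", "grandparent")]}
parent = {"base_models": [grandparent], "properties": [("id", "parent"), ("name", "parent")]}
child = {"base_models": [parent], "properties": [("name", "child")]}
print(merge(child, child["properties"]))
# [('etag', 'grandparent'), ('id', 'parent'), ('name', 'child')]
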
@@ -30,12 +30,11 @@ class ConstantSchema(BaseSchema):
        self.schema = schema

    def get_declaration(self, value: Any):
        raise TypeError("Should not call get_declaration on a ConstantSchema. Use the constant_value property instead")

    @property
    def constant_value(self) -> str:
        """This string is used directly in template, as-is
        """
        if value != self.value:
            _LOGGER.warning(
                "Passed in value of %s differs from constant value of %s. Choosing constant value",
                str(value), str(self.value)
            )
        if self.value is None:
            return "None"
        return self.schema.get_declaration(self.value)

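Illustrative sketch (not part of the commit): the warning branch added to get_declaration fires when a caller passes a value that differs from the constant, and the constant always wins. Reduced here to a free function over plain values:

import logging
logging.basicConfig(level=logging.WARNING)
_LOGGER = logging.getLogger(__name__)

def get_declaration(constant_value, value):
    if value != constant_value:
        _LOGGER.warning(
            "Passed in value of %s differs from constant value of %s. Choosing constant value",
            str(value), str(constant_value)
        )
    return repr(constant_value)

print(get_declaration("application/json", "text/plain"))  # warns, then returns "'application/json'"
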
@@ -0,0 +1,84 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from .base_schema import BaseSchema
from .imports import FileImport, ImportType, TypingSection


class CredentialSchema(BaseSchema):
    def __init__(self) -> None:  # pylint: disable=super-init-not-called
        self.default_value = None

    @property
    def docstring_type(self) -> str:
        return self.serialization_type

    @property
    def docstring_text(self) -> str:
        return "credential"

    @property
    def serialization_type(self) -> str:
        # this property is added, because otherwise pylint says that
        # abstract serialization_type in BaseSchema is not overridden
        pass


class AzureKeyCredentialSchema(CredentialSchema):

    @property
    def serialization_type(self) -> str:
        return "~azure.core.credentials.AzureKeyCredential"

    @property
    def type_annotation(self) -> str:
        return "AzureKeyCredential"

    def imports(self) -> FileImport:
        file_import = FileImport()
        file_import.add_from_import(
            "azure.core.credentials",
            "AzureKeyCredential",
            ImportType.AZURECORE,
            typing_section=TypingSection.CONDITIONAL
        )
        return file_import


class TokenCredentialSchema(CredentialSchema):
    def __init__(self, async_mode) -> None:
        super(TokenCredentialSchema, self).__init__()
        self.async_mode = async_mode
        self.async_type = "~azure.core.credentials_async.AsyncTokenCredential"
        self.sync_type = "~azure.core.credentials.TokenCredential"

    @property
    def serialization_type(self) -> str:
        if self.async_mode:
            return self.async_type
        return self.sync_type

    @property
    def type_annotation(self) -> str:
        if self.async_mode:
            return '"AsyncTokenCredential"'
        return '"TokenCredential"'


    def imports(self) -> FileImport:
        file_import = FileImport()
        if self.async_mode:
            file_import.add_from_import(
                "azure.core.credentials_async", "AsyncTokenCredential",
                ImportType.AZURECORE,
                typing_section=TypingSection.TYPING
            )
        else:
            file_import.add_from_import(
                "azure.core.credentials", "TokenCredential",
                ImportType.AZURECORE,
                typing_section=TypingSection.TYPING
            )
        return file_import

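Illustrative sketch (not part of the commit): a standalone stand-in for TokenCredentialSchema above, showing how the async flag drives the rendered type strings; the FileImport machinery is omitted.

class TokenCredentialSketch:
    def __init__(self, async_mode: bool) -> None:
        self.async_mode = async_mode
        self.async_type = "~azure.core.credentials_async.AsyncTokenCredential"
        self.sync_type = "~azure.core.credentials.TokenCredential"

    @property
    def serialization_type(self) -> str:
        return self.async_type if self.async_mode else self.sync_type

print(TokenCredentialSketch(async_mode=True).serialization_type)
# ~azure.core.credentials_async.AsyncTokenCredential
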
@@ -13,7 +13,6 @@ class ImportType(str, Enum):
    AZURECORE = "azurecore"
    LOCAL = "local"


class TypingSection(str, Enum):
    REGULAR = "regular"  # this import is always a typing import
    CONDITIONAL = "conditional"  # is a typing import when we're dealing with files that py2 will use, else regular

@@ -24,6 +24,7 @@ class LROOperation(Operation):
        description: str,
        url: str,
        method: str,
        multipart: bool,
        api_versions: Set[str],
        requests: List[SchemaRequest],
        summary: Optional[str] = None,

@@ -32,7 +33,7 @@ class LROOperation(Operation):
        responses: Optional[List[SchemaResponse]] = None,
        exceptions: Optional[List[SchemaResponse]] = None,
        want_description_docstring: bool = True,
        want_tracing: bool = True,
        want_tracing: bool = True
    ) -> None:
        super(LROOperation, self).__init__(
            yaml_data,

@@ -40,6 +41,7 @@ class LROOperation(Operation):
            description,
            url,
            method,
            multipart,
            api_versions,
            requests,
            summary,

@@ -65,10 +67,11 @@ class LROOperation(Operation):
            rt for rt in response_types if 200 in rt.status_codes
        ]
        if not response_types_with_200_status_code:
            raise ValueError("Your swagger is invalid because you have "
                             "multiple response schemas for LRO method "
                             f"{self.python_name} and none of them have a "
                             "200 status code.")
            raise ValueError(
                "Your swagger is invalid because you have multiple response"
                " schemas for LRO" + f" method {self.python_name} and "
                "none of them have a 200 status code."
            )
        response_type = response_types_with_200_status_code[0]

        response_type_schema_name = cast(BaseSchema, response_type.schema).serialization_type

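Illustrative sketch (not part of the commit): the validation above requires at least one LRO response type with a 200 status code. Reduced to plain data:

response_types = [{"status_codes": [200, 201]}, {"status_codes": [204]}]
response_types_with_200_status_code = [
    rt for rt in response_types if 200 in rt["status_codes"]
]
if not response_types_with_200_status_code:
    raise ValueError("no response schema for this LRO method has a 200 status code")
print(response_types_with_200_status_code[0]["status_codes"])  # [200, 201]
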
@@ -0,0 +1,59 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from typing import Any, Dict, List, Set, Optional
from .lro_operation import LROOperation
from .paging_operation import PagingOperation
from .imports import FileImport
from .schema_request import SchemaRequest
from .parameter import Parameter
from .schema_response import SchemaResponse

class LROPagingOperation(PagingOperation, LROOperation):
    def __init__(
        self,
        yaml_data: Dict[str, Any],
        name: str,
        description: str,
        url: str,
        method: str,
        multipart: bool,
        api_versions: Set[str],
        requests: List[SchemaRequest],
        summary: Optional[str] = None,
        parameters: Optional[List[Parameter]] = None,
        multiple_media_type_parameters: Optional[List[Parameter]] = None,
        responses: Optional[List[SchemaResponse]] = None,
        exceptions: Optional[List[SchemaResponse]] = None,
        want_description_docstring: bool = True,
        want_tracing: bool = True
    ) -> None:
        super(LROPagingOperation, self).__init__(
            yaml_data,
            name,
            description,
            url,
            method,
            multipart,
            api_versions,
            requests,
            summary,
            parameters,
            multiple_media_type_parameters,
            responses,
            exceptions,
            want_description_docstring,
            want_tracing,
            override_success_response_to_200=True
        )

    def imports(self, code_model, async_mode: bool) -> FileImport:
        lro_imports = LROOperation.imports(self, code_model, async_mode)
        paging_imports = PagingOperation.imports(self, code_model, async_mode)

        file_import = lro_imports
        file_import.merge(paging_imports)
        return file_import

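Illustrative sketch (not part of the commit): LROPagingOperation relies on Python's method resolution order to combine both behaviours; the empty classes below are stand-ins showing the lookup order.

class Operation: ...
class LROOperation(Operation): ...
class PagingOperation(Operation): ...
class LROPagingOperation(PagingOperation, LROOperation): ...

print([c.__name__ for c in LROPagingOperation.__mro__])
# ['LROPagingOperation', 'PagingOperation', 'LROOperation', 'Operation', 'object']
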
@@ -125,25 +125,26 @@ class ObjectSchema(BaseSchema):  # pylint: disable=too-many-instance-attributes
        # checking to see if this is a polymorphic class
        subtype_map = None
        if yaml_data.get("discriminator"):
            subtype_map = {
                children_yaml["discriminatorValue"]: children_yaml["language"][
                    "python"
                ]["name"]
                for children_yaml in yaml_data["discriminator"][
                    "immediate"
                ].values()
            }
            subtype_map = {}
            # map of discriminator value to child's name
            for children_yaml in yaml_data["discriminator"]["immediate"].values():
                subtype_map[children_yaml["discriminatorValue"]] = children_yaml["language"]["python"]["name"]

        if yaml_data.get("properties"):
            properties += [
                Property.from_yaml(p, has_additional_properties=len(properties) > 0, **kwargs)
                for p in yaml_data["properties"]
            ]
        # this is to ensure that the attribute map type and property type are generated correctly


        description = yaml_data["language"]["python"]["description"]
        is_exception = False
        exceptions_set = kwargs.pop("exceptions_set", None)
        if exceptions_set and id(yaml_data) in exceptions_set:
            is_exception = True
        if exceptions_set:
            if id(yaml_data) in exceptions_set:
                is_exception = True

        self.yaml_data = yaml_data
        self.name = name

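Illustrative sketch (not part of the commit): the subtype_map built above maps each discriminator value to the child model's Python name. With plain dicts shaped like the yaml data:

yaml_data = {"discriminator": {"immediate": {
    "Cat": {"discriminatorValue": "cat", "language": {"python": {"name": "Cat"}}},
    "Dog": {"discriminatorValue": "dog", "language": {"python": {"name": "Dog"}}},
}}}
subtype_map = {}
# map of discriminator value to child's name, as in the diff above
for children_yaml in yaml_data["discriminator"]["immediate"].values():
    subtype_map[children_yaml["discriminatorValue"]] = children_yaml["language"]["python"]["name"]
print(subtype_map)  # {'cat': 'Cat', 'dog': 'Dog'}
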
@@ -15,6 +15,7 @@ from .base_schema import BaseSchema
from .schema_request import SchemaRequest
from .object_schema import ObjectSchema
from .constant_schema import ConstantSchema
from .list_schema import ListSchema


_LOGGER = logging.getLogger(__name__)

@@ -22,6 +23,7 @@ _LOGGER = logging.getLogger(__name__)
T = TypeVar('T')
OrderedSet = Dict[T, None]

_M4_HEADER_PARAMETERS = ["content_type", "accept"]

def _non_binary_schema_media_types(media_types: List[str]) -> OrderedSet[str]:
    response_media_types: OrderedSet[str] = {}

@@ -29,7 +31,7 @@ def _non_binary_schema_media_types(media_types: List[str]) -> OrderedSet[str]:
    json_media_types = [media_type for media_type in media_types if "json" in media_type]
    xml_media_types = [media_type for media_type in media_types if "xml" in media_type]

    if sorted(json_media_types + xml_media_types) != sorted(media_types):
    if not sorted(json_media_types + xml_media_types) == sorted(media_types):
        raise ValueError("The non-binary responses with schemas of {self.name} have incorrect json or xml mime types")
    if json_media_types:
        if "application/json" in json_media_types:

@@ -43,53 +45,24 @@ def _non_binary_schema_media_types(media_types: List[str]) -> OrderedSet[str]:
        response_media_types[xml_media_types[0]] = None
    return response_media_types

def _remove_multiple_content_type_parameters(parameters: List[Parameter]) -> List[Parameter]:
    content_type_params = [p for p in parameters if p.serialized_name == "content_type"]
    remaining_params = [p for p in parameters if p.serialized_name != "content_type"]
    json_content_type_param = [p for p in content_type_params if p.yaml_data["schema"]["type"] == "constant"]
    if json_content_type_param:
        remaining_params.append(json_content_type_param[0])
    else:
        remaining_params.append(content_type_params[0])
def _remove_multiple_m4_header_parameters(parameters: List[Parameter]) -> List[Parameter]:
    m4_header_params_in_schema = {
        k: [p for p in parameters if p.serialized_name == k]
        for k in _M4_HEADER_PARAMETERS
    }
    remaining_params = [p for p in parameters if p.serialized_name not in _M4_HEADER_PARAMETERS]
    json_m4_header_params = {
        k: [p for p in m4_header_params_in_schema[k] if p.yaml_data["schema"]["type"] == "constant"]
        for k in m4_header_params_in_schema
    }
    for k, v in json_m4_header_params.items():
        if v:
            remaining_params.append(v[0])
        else:
            remaining_params.append(m4_header_params_in_schema[k][0])

    return remaining_params

def _accept_content_type_helper(responses: List[SchemaResponse]) -> OrderedSet[str]:
    media_types = {
        media_type: None for response in responses for media_type in response.media_types
    }

    if not media_types:
        return media_types

    if len(media_types.keys()) == 1:
        # if there's just one media type, we return it
        return media_types
    # if not, we want to return them as binary_media_types + non_binary_media types
    binary_media_types = {
        media_type: None
        for media_type in list(media_types.keys())
        if "json" not in media_type and "xml" not in media_type
    }

    non_binary_schema_media_types = {
        media_type: None for media_type in list(media_types.keys())
        if "json" in media_type or "xml" in media_type
    }
    if all(response.binary for response in responses):
        response_media_types = binary_media_types
    elif all(response.schema for response in responses):
        response_media_types = _non_binary_schema_media_types(
            list(non_binary_schema_media_types.keys())
        )
    else:
        non_binary_schema_media_types = _non_binary_schema_media_types(
            list(non_binary_schema_media_types.keys())
        )
        response_media_types = binary_media_types
        response_media_types.update(non_binary_schema_media_types)

    return response_media_types

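Illustrative sketch (not part of the commit): _remove_multiple_m4_header_parameters keeps one content_type/accept parameter per name, preferring a constant-typed one. A simplified stand-in over (name, schema_type) tuples; unlike the real helper, this guards against a name being absent:

def dedupe(params):
    keep = [p for p in params if p[0] not in ("content_type", "accept")]
    for name in ("content_type", "accept"):
        candidates = [p for p in params if p[0] == name]
        constants = [p for p in candidates if p[1] == "constant"]
        if candidates:
            keep.append(constants[0] if constants else candidates[0])
    return keep

print(dedupe([("content_type", "string"), ("content_type", "constant"), ("body", "object")]))
# [('body', 'object'), ('content_type', 'constant')]
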
class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many-instance-attributes
    """Represent an operation.
    """

@@ -101,6 +74,7 @@ class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many
        description: str,
        url: str,
        method: str,
        multipart: bool,
        api_versions: Set[str],
        requests: List[SchemaRequest],
        summary: Optional[str] = None,

@@ -109,13 +83,14 @@ class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many
        responses: Optional[List[SchemaResponse]] = None,
        exceptions: Optional[List[SchemaResponse]] = None,
        want_description_docstring: bool = True,
        want_tracing: bool = True,
        want_tracing: bool = True
    ) -> None:
        super().__init__(yaml_data)
        self.name = name
        self.description = description
        self.url = url
        self.method = method
        self.multipart = multipart
        self.api_versions = api_versions
        self.requests = requests
        self.summary = summary

@@ -134,29 +109,11 @@ class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many
    def request_content_type(self) -> str:
        return next(iter(
            [
                cast(ConstantSchema, p.schema).constant_value for p in self.parameters.constant
                if p.serialized_name == "content_type"
                p.schema.get_declaration(cast(ConstantSchema, p.schema).value)
                for p in self.parameters.constant if p.serialized_name == "content_type"
            ]
        ))

    @property
    def accept_content_type(self) -> str:
        if not self.has_response_body:
            raise TypeError(
                "There is an error in the code model we're being supplied. We're getting response media types " +
                f"even though no response of {self.name} has a body"
            )
        response_content_types = _accept_content_type_helper(self.responses)
        response_content_types.update(_accept_content_type_helper(self.exceptions))

        if not response_content_types.keys():
            raise TypeError(
                f"Operation {self.name} has tried to get its accept_content_type even though it has no media types"
            )

        return ", ".join(list(response_content_types.keys()))


    @property
    def is_stream_request(self) -> bool:
        """Is the request is a stream, like an upload."""

@@ -199,7 +156,7 @@ class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many
        if parameter.skip_url_encoding:
            optional_parameters.append("skip_quote=True")

        if parameter.style:
        if parameter.style and not parameter.explode:
            if parameter.style in [ParameterStyle.simple, ParameterStyle.form]:
                div_char = ","
            elif parameter.style in [ParameterStyle.spaceDelimited]:

@@ -212,17 +169,32 @@ class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many
                raise ValueError(f"Do not support {parameter.style} yet")
            optional_parameters.append(f"div='{div_char}'")

        serialization_constraints = parameter.schema.serialization_constraints
        optional_parameters += serialization_constraints if serialization_constraints else ""
        if parameter.explode:
            if not isinstance(parameter.schema, ListSchema):
                raise ValueError("Got a explode boolean on a non-array schema")
            serialization_schema = parameter.schema.element_type
        else:
            serialization_schema = parameter.schema

        optional_parameters_string = "" if not optional_parameters else ", " + ", ".join(optional_parameters)
        serialization_constraints = serialization_schema.serialization_constraints
        if serialization_constraints:
            optional_parameters += serialization_constraints

        origin_name = parameter.full_serialized_name

        return (
            f"""self._serialize.{function_name}("{origin_name.lstrip('_')}", {origin_name}, """
            + f"""'{parameter.schema.serialization_type}'{optional_parameters_string})"""
        )
        parameters = [
            f'"{origin_name.lstrip("_")}"',
            "q" if parameter.explode else origin_name,
            f"'{serialization_schema.serialization_type}'",
            *optional_parameters
        ]
        parameters_line = ', '.join(parameters)

        serialize_line = f'self._serialize.{function_name}({parameters_line})'

        if parameter.explode:
            return f"[{serialize_line} if q is not None else '' for q in {origin_name}]"
        return serialize_line

    @property
    def serialization_context(self) -> str:

@@ -323,7 +295,7 @@ class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many
        for request in yaml_data["requests"]:
            for yaml in request.get("parameters", []):
                parameter = Parameter.from_yaml(yaml)
                if yaml["language"]["python"]["name"] == "content_type":
                if yaml["language"]["python"]["name"] in _M4_HEADER_PARAMETERS:
                    parameter.is_kwarg = True
                    parameters.append(parameter)
                elif multiple_requests:

@@ -332,7 +304,7 @@ class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many
                    parameters.append(parameter)

        if multiple_requests:
            parameters = _remove_multiple_content_type_parameters(parameters)
            parameters = _remove_multiple_m4_header_parameters(parameters)
            chosen_parameter = multiple_media_type_parameters[0]

            # binary body parameters are required, while object

@@ -346,10 +318,9 @@ class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many
            parameters.append(chosen_parameter)

        if multiple_media_type_parameters:
            body_parameters_name_set = {
            body_parameters_name_set = set(
                p.serialized_name for p in multiple_media_type_parameters
            }

            )
            if len(body_parameters_name_set) > 1:
                raise ValueError(
                    f"The body parameter with multiple media types has different names: {body_parameters_name_set}"

@@ -374,22 +345,13 @@ class Operation(BaseModel):  # pylint: disable=too-many-public-methods, too-many
            description=yaml_data["language"]["python"]["description"],
            url=first_request["protocol"]["http"]["path"],
            method=first_request["protocol"]["http"]["method"],
            api_versions={
                value_dict["version"] for value_dict in yaml_data["apiVersions"]
            },
            requests=[
                SchemaRequest.from_yaml(yaml) for yaml in yaml_data["requests"]
            ],
            multipart=first_request["protocol"]["http"].get("multipart", False),
            api_versions=set(value_dict["version"] for value_dict in yaml_data["apiVersions"]),
            requests=[SchemaRequest.from_yaml(yaml) for yaml in yaml_data["requests"]],
            summary=yaml_data["language"]["python"].get("summary"),
            parameters=parameters,
            multiple_media_type_parameters=multiple_media_type_parameters,
            responses=[
                SchemaResponse.from_yaml(yaml)
                for yaml in yaml_data.get("responses", [])
            ],
            exceptions=[
                SchemaResponse.from_yaml(yaml)
                for yaml in yaml_data.get("exceptions", [])
                if "schema" in yaml
            ],
            responses=[SchemaResponse.from_yaml(yaml) for yaml in yaml_data.get("responses", [])],
            # Exception with no schema means default exception, we don't store them
            exceptions=[SchemaResponse.from_yaml(yaml) for yaml in yaml_data.get("exceptions", []) if "schema" in yaml],
        )

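Illustrative sketch (not part of the commit): the api_versions construction collapsed to one line above; both the old set comprehension and the new set() call build the same set:

yaml_data = {"apiVersions": [{"version": "2020-06-01"}, {"version": "2020-07-01"}]}
api_versions = set(value_dict["version"] for value_dict in yaml_data["apiVersions"])
print(sorted(api_versions))  # ['2020-06-01', '2020-07-01']
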
@@ -10,6 +10,7 @@ from .base_model import BaseModel
from .operation import Operation
from .lro_operation import LROOperation
from .paging_operation import PagingOperation
from .lro_paging_operation import LROPagingOperation
from .imports import FileImport, ImportType


@@ -36,14 +37,21 @@ class OperationGroup(BaseModel):
        self.operations = operations
        self.api_versions = api_versions

    @staticmethod
    def imports(async_mode: bool, has_schemas: bool) -> FileImport:
    def imports(self, async_mode: bool, has_schemas: bool) -> FileImport:
        file_import = FileImport()
        file_import.add_import("logging", ImportType.STDLIB)
        file_import.add_from_import("azure.functions", "HttpRequest",
                                    ImportType.AZURECORE)
        file_import.add_from_import("azure.functions", "HttpResponse",
                                    ImportType.AZURECORE)
        file_import.add_from_import("azure.core.exceptions", "ResourceNotFoundError", ImportType.AZURECORE)
        file_import.add_from_import("azure.core.exceptions", "ResourceExistsError", ImportType.AZURECORE)
        for operation in self.operations:
            file_import.merge(operation.imports(self.code_model, async_mode))
        if self.code_model.options["tracing"]:
            if async_mode:
                file_import.add_from_import(
                    "azure.core.tracing.decorator_async", "distributed_trace_async", ImportType.AZURECORE,
                )
            else:
                file_import.add_from_import(
                    "azure.core.tracing.decorator", "distributed_trace", ImportType.AZURECORE,
                )
        if has_schemas:
            if async_mode:
                file_import.add_from_import("...", "models", ImportType.LOCAL)

@@ -75,9 +83,13 @@ class OperationGroup(BaseModel):
        operations = []
        api_versions: Set[str] = set()
        for operation_yaml in yaml_data["operations"]:
            if operation_yaml.get("extensions", {}).get("x-ms-long-running-operation"):
            lro_operation = operation_yaml.get("extensions", {}).get("x-ms-long-running-operation")
            paging_operation = operation_yaml.get("extensions", {}).get("x-ms-pageable")
            if lro_operation and paging_operation:
                operation = LROPagingOperation.from_yaml(operation_yaml)
            elif lro_operation:
                operation = LROOperation.from_yaml(operation_yaml)
            elif operation_yaml.get("extensions", {}).get("x-ms-pageable"):
            elif paging_operation:
                operation = PagingOperation.from_yaml(operation_yaml)
            else:
                operation = Operation.from_yaml(operation_yaml)

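Illustrative sketch (not part of the commit): the dispatch above in from_yaml, reduced to its shape. The extension keys are the real swagger extension names; the returned strings stand in for the operation classes:

def pick_operation_class(extensions):
    lro = extensions.get("x-ms-long-running-operation")
    paging = extensions.get("x-ms-pageable")
    if lro and paging:
        return "LROPagingOperation"
    if lro:
        return "LROOperation"
    if paging:
        return "PagingOperation"
    return "Operation"

print(pick_operation_class({"x-ms-pageable": {"nextLinkName": "nextLink"}}))
# PagingOperation
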
@@ -4,62 +4,68 @@
# license information.
# --------------------------------------------------------------------------
import logging
from typing import Any, Dict, List, Optional, Set, cast
from typing import cast, Dict, List, Any, Optional, Set, Union

from .imports import FileImport, ImportType, TypingSection
from .object_schema import ObjectSchema
from .operation import Operation
from .parameter import Parameter
from .schema_request import SchemaRequest
from .schema_response import SchemaResponse
from .schema_request import SchemaRequest
from .imports import ImportType, FileImport, TypingSection
from .object_schema import ObjectSchema

_LOGGER = logging.getLogger(__name__)


class PagingOperation(Operation):
    def __init__(
        self,
        yaml_data: Dict[str, Any],
        name: str,
        description: str,
        url: str,
        method: str,
        api_versions: Set[str],
        requests: List[SchemaRequest],
        summary: Optional[str] = None,
        parameters: Optional[List[Parameter]] = None,
        multiple_media_type_parameters: Optional[List[Parameter]] = None,
        responses: Optional[List[SchemaResponse]] = None,
        exceptions: Optional[List[SchemaResponse]] = None,
        self,
        yaml_data: Dict[str, Any],
        name: str,
        description: str,
        url: str,
        method: str,
        multipart: bool,
        api_versions: Set[str],
        requests: List[SchemaRequest],
        summary: Optional[str] = None,
        parameters: Optional[List[Parameter]] = None,
        multiple_media_type_parameters: Optional[List[Parameter]] = None,
        responses: Optional[List[SchemaResponse]] = None,
        exceptions: Optional[List[SchemaResponse]] = None,
        want_description_docstring: bool = True,
        want_tracing: bool = True,
        *,
        override_success_response_to_200: bool = False
    ) -> None:
        super(PagingOperation, self).__init__(
            yaml_data,
            name,
            description,
            url,
            method,
            api_versions,
            requests,
            summary,
            parameters,
            multiple_media_type_parameters,
            responses,
            exceptions
            yaml_data,
            name,
            description,
            url,
            method,
            multipart,
            api_versions,
            requests,
            summary,
            parameters,
            multiple_media_type_parameters,
            responses,
            exceptions,
            want_description_docstring,
            want_tracing
        )
        self._item_name: str = yaml_data["extensions"]["x-ms-pageable"].get(
            "itemName")
        self._next_link_name: str = yaml_data["extensions"][
            "x-ms-pageable"].get("nextLinkName")
        self.operation_name: str = yaml_data["extensions"]["x-ms-pageable"].get(
            "operationName")
        self._item_name: str = yaml_data["extensions"]["x-ms-pageable"].get("itemName")
        self._next_link_name: str = yaml_data["extensions"]["x-ms-pageable"].get("nextLinkName")
        self.operation_name: str = yaml_data["extensions"]["x-ms-pageable"].get("operationName")
        self.next_operation: Optional[Operation] = None
        self.override_success_response_to_200 = override_success_response_to_200

    def _get_response(self) -> SchemaResponse:
        response = self.responses[0]
        if not isinstance(response.schema, ObjectSchema):
            raise ValueError(
                "The response of a paging operation must be of type " +
                f"ObjectSchema but {response.schema} is not"
                "The response of a paging operation must be of type " +
                f"ObjectSchema but {response.schema} is not"
            )
        return response

@@ -70,24 +76,26 @@ class PagingOperation(Operation):
        for prop in response_schema.properties:
            if prop.original_swagger_name == rest_api_name:
                return prop.name
        raise ValueError("While scanning x-ms-pageable, was unable to find "
                         f"{log_name}:{rest_api_name} in model"
                         f" {response_schema.name}")
        raise ValueError(
            "While scanning x-ms-pageable, was unable to find "
            + f"{log_name}:{rest_api_name} in model {response_schema.name}"
        )

    @property
    def item_name(self) -> str:
        if self._item_name is None:
            # Default value. I still check if I find it, so I can do a nice
            # message.
            # Default value. I still check if I find it,
            # so I can do a nice message.
            item_name = "value"
            try:
                return self._find_python_name(item_name, "itemName")
            except ValueError:
                response = self._get_response()
                raise ValueError("While scanning x-ms-pageable, itemName was "
                                 "not defined and object "
                                 f"{response.schema.name} has no array "
                                 "called 'value'")
                raise ValueError(
                    "While scanning x-ms-pageable, itemName was not defined and"
                    " object" + f" {response.schema.name} has no array called "
                    "'value'"
                )
        return self._find_python_name(self._item_name, "itemName")

    @property

@@ -99,32 +107,31 @@ class PagingOperation(Operation):

    @property
    def has_optional_return_type(self) -> bool:
        """A paging will never have an optional return type, we will always
        return ItemPaged[return type]"""
        """A paging will never have an optional return type, we will always return ItemPaged[return type]"""
        return False

    @property
    def success_status_code(self) -> List[Union[str, int]]:
        """The list of all successfull status code.
        """
        if self.override_success_response_to_200:
            return [200]
        return super(PagingOperation, self).success_status_code

    def imports(self, code_model, async_mode: bool) -> FileImport:
        file_import = super(PagingOperation, self).imports(code_model,
                                                           async_mode)
        file_import = super(PagingOperation, self).imports(code_model, async_mode)

        if async_mode:
            file_import.add_from_import("azure.core.async_paging",
                                        "AsyncItemPaged", ImportType.AZURECORE)
            file_import.add_from_import("azure.core.async_paging", "AsyncList",
                                        ImportType.AZURECORE)
            file_import.add_from_import("typing", "AsyncIterable",
                                        ImportType.STDLIB,
                                        TypingSection.CONDITIONAL)
            file_import.add_from_import("azure.core.async_paging", "AsyncItemPaged", ImportType.AZURECORE)
            file_import.add_from_import("azure.core.async_paging", "AsyncList", ImportType.AZURECORE)
            file_import.add_from_import("typing", "AsyncIterable", ImportType.STDLIB, TypingSection.CONDITIONAL)
        else:
            file_import.add_from_import("azure.core.paging", "ItemPaged",
                                        ImportType.AZURECORE)
            file_import.add_from_import("typing", "Iterable", ImportType.STDLIB,
                                        TypingSection.CONDITIONAL)
            file_import.add_from_import("azure.core.paging", "ItemPaged", ImportType.AZURECORE)
            file_import.add_from_import("typing", "Iterable", ImportType.STDLIB, TypingSection.CONDITIONAL)

        if code_model.options["tracing"]:
            file_import.add_from_import(
                "azure.core.tracing.decorator", "distributed_trace",
                ImportType.AZURECORE,
                "azure.core.tracing.decorator", "distributed_trace", ImportType.AZURECORE,
            )

        return file_import

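Illustrative sketch (not part of the commit): how item_name falls back to "value" when x-ms-pageable gives no itemName, mirroring the property above with plain dicts instead of schema models:

def resolve_item_name(extensions, model_properties):
    item_name = extensions["x-ms-pageable"].get("itemName") or "value"
    if item_name not in model_properties:
        raise ValueError(f"object has no array called '{item_name}'")
    return item_name

print(resolve_item_name({"x-ms-pageable": {}}, {"value": [], "nextLink": None}))  # value
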
@@ -37,6 +37,7 @@ class ParameterStyle(Enum):
    json = "json"
    binary = "binary"
    xml = "xml"
    multipart = "multipart"


class Parameter(BaseModel):  # pylint: disable=too-many-instance-attributes

@@ -54,6 +55,7 @@ class Parameter(BaseModel):  # pylint: disable=too-many-instance-attributes
        constraints: List[Any],
        target_property_name: Optional[Union[int, str]] = None,  # first uses id as placeholder
        style: Optional[ParameterStyle] = None,
        explode: Optional[bool] = False,
        *,
        flattened: bool = False,
        grouped_by: Optional["Parameter"] = None,

@@ -72,6 +74,7 @@ class Parameter(BaseModel):  # pylint: disable=too-many-instance-attributes
        self.constraints = constraints
        self.target_property_name = target_property_name
        self.style = style
        self.explode = explode
        self.flattened = flattened
        self.grouped_by = grouped_by
        self.original_parameter = original_parameter

@@ -131,7 +134,7 @@ class Parameter(BaseModel):  # pylint: disable=too-many-instance-attributes
            default_value_declaration = "None"
        else:
            if isinstance(self.schema, ConstantSchema):
                default_value = self.schema.constant_value
                default_value = self.schema.get_declaration(self.schema.value)
                default_value_declaration = default_value
            else:
                default_value = self.schema.default_value

@@ -194,6 +197,7 @@ class Parameter(BaseModel):  # pylint: disable=too-many-instance-attributes
            constraints=[],  # FIXME constraints
            target_property_name=id(yaml_data["targetProperty"]) if yaml_data.get("targetProperty") else None,
            style=ParameterStyle(http_protocol["style"]) if "style" in http_protocol else None,
            explode=http_protocol.get("explode", False),
            grouped_by=yaml_data.get("groupedBy", None),
            original_parameter=yaml_data.get("originalParameter", None),
            flattened=yaml_data.get("flattened", False),

@@ -56,11 +56,11 @@ class ParameterList(MutableSequence):
        return self.has_any_location(ParameterLocation.Body)

    @property
    def body(self) -> Parameter:
    def body(self) -> List[Parameter]:
        if not self.has_body:
            raise ValueError("Can't get body parameter")
        # Should we check if there is two body? Modeler role right?
        return self.get_from_location(ParameterLocation.Body)[0]
        return self.get_from_location(ParameterLocation.Body)

    @property
    def path(self) -> List[Parameter]:

@@ -140,5 +140,5 @@ class ParameterList(MutableSequence):
                for param in parameters if param.target_property_name
            ]
        )
        object_schema = cast(ObjectSchema, self.body.schema)
        return f"{self.body.serialized_name} = models.{object_schema.name}({parameter_string})"
        object_schema = cast(ObjectSchema, self.body[0].schema)
        return f"{self.body[0].serialized_name} = models.{object_schema.name}({parameter_string})"

@@ -19,7 +19,7 @@ class RawString(object):
        self.string = string

    def __repr__(self) -> str:
        return f"r'{self.string}'"
        return "r'{}'".format(self.string.replace('\'', '\\\''))


class PrimitiveSchema(BaseSchema):

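Illustrative sketch (not part of the commit): the new __repr__ escapes embedded single quotes, which the old f-string version did not:

s = "it's"
print("r'{}'".format(s.replace("'", "\\'")))  # r'it\'s'
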
@@ -4,8 +4,12 @@
# license information.
# --------------------------------------------------------------------------
from jinja2 import Environment
from autorest.codegen.serializers.azure_functions.python.import_serializer import FileImportSerializer, TypingSection
from ..models import FileImport, ImportType, CodeModel, CredentialSchema

from autorest.codegen.models import TypingSection
from autorest.codegen.serializers.azure_functions.python.import_serializer \
    import \
    FileImportSerializer
from ..models import FileImport, ImportType, CodeModel, TokenCredentialSchema


class GeneralSerializer:

@@ -24,9 +28,9 @@ class GeneralSerializer:

    def _correct_credential_parameter(self):
        credential_param = [
            gp for gp in self.code_model.global_parameters.parameters if isinstance(gp.schema, CredentialSchema)
            gp for gp in self.code_model.global_parameters.parameters if isinstance(gp.schema, TokenCredentialSchema)
        ][0]
        credential_param.schema = CredentialSchema(async_mode=self.async_mode)
        credential_param.schema = TokenCredentialSchema(async_mode=self.async_mode)

    def serialize_service_client_file(self) -> str:
        def _service_client_imports() -> FileImport:

@@ -37,7 +41,10 @@ class GeneralSerializer:

        template = self.env.get_template("service_client.py.jinja2")

        if self.code_model.options['credential']:
        if (
            self.code_model.options['credential'] and
            self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
        ):
            self._correct_credential_parameter()

        return template.render(

@@ -59,6 +66,8 @@ class GeneralSerializer:
            file_import.add_from_import(".._version" if async_mode else "._version", "VERSION", ImportType.LOCAL)
            for gp in self.code_model.global_parameters:
                file_import.merge(gp.imports())
            if self.code_model.options["azure_arm"]:
                file_import.add_from_import("azure.mgmt.core.policies", "ARMHttpLoggingPolicy", ImportType.AZURECORE)
            return file_import

        package_name = self.code_model.options['package_name']

@@ -66,7 +75,10 @@ class GeneralSerializer:
            package_name = package_name[len("azure-"):]
        sdk_moniker = package_name if package_name else self.code_model.class_name.lower()

        if self.code_model.options['credential']:
        if (
            self.code_model.options['credential'] and
            self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
        ):
            self._correct_credential_parameter()

        template = self.env.get_template("config.py.jinja2")

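Illustrative sketch (not part of the commit): the widened guard above only swaps in TokenCredentialSchema when both options hold; with a plain dict of options:

options = {"credential": True, "credential_default_policy_type": "AzureKeyCredentialPolicy"}
needs_token_credential = (
    options["credential"] and
    options["credential_default_policy_type"] == "BearerTokenCredentialPolicy"
)
print(needs_token_credential)  # False
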
@@ -5,25 +5,22 @@
# --------------------------------------------------------------------------
import copy
import json
from typing import List, Optional, Set, Tuple, Dict
from typing import Dict, List, Optional, Set, Tuple

from jinja2 import Environment
from ..models import (
    CodeModel,
    Operation,
    OperationGroup,
    LROOperation,
    PagingOperation,
    CredentialSchema,
    ParameterList,
    TypingSection,
    ImportType
)

from ..models import (CodeModel, ImportType, LROOperation, Operation,
                      OperationGroup, PagingOperation, ParameterList,
                      TokenCredentialSchema, TypingSection)


def _correct_credential_parameter(global_parameters: ParameterList, async_mode: bool) -> None:
    credential_param = [
        gp for gp in global_parameters.parameters if isinstance(gp.schema, CredentialSchema)
        gp for gp in global_parameters.parameters if
        isinstance(gp.schema, TokenCredentialSchema)
    ][0]
    credential_param.schema = CredentialSchema(async_mode=async_mode)
    credential_param.schema = TokenCredentialSchema(async_mode=async_mode)


def _json_serialize_imports(
    imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]]

@@ -32,9 +29,9 @@ def _json_serialize_imports(
        return None

    json_serialize_imports = {}
    # need to make name_import set -> list to make the dictionary json serializable
    # not using an OrderedDict since we're iterating through a set and the order there varies
    # going to sort the list instead
    # need to make name_import () -> [] to make the dictionary json serializable
    # not using an OrderedDict since we're iterating through a set and the order
    # there varies going to sort the list instead

    for typing_section_key, typing_section_value in imports.items():
        json_import_type_dictionary = {}

@@ -107,11 +104,19 @@ class MetadataSerializer:
        # In this case, we need two copies of the credential global parameter
        # for typing purposes.
        async_global_parameters = self.code_model.global_parameters
        if self.code_model.options['credential']:
            # this ensures that the CredentialSchema showing up in the list of code model's global parameters
            # is sync. This way we only have to make a copy for an async_credential
            _correct_credential_parameter(self.code_model.global_parameters, False)
            async_global_parameters = self._make_async_copy_of_global_parameters()
        if (
            self.code_model.options['credential'] and
            self.code_model.options[
                'credential_default_policy_type'] ==
            "BearerTokenCredentialPolicy"
        ):
            # this ensures that the TokenCredentialSchema showing up in the
            # list of code model's global parameters is sync.
            # This way we only have to make a copy for an async_credential
            _correct_credential_parameter(self.code_model.global_parameters,
                                          False)
            async_global_parameters = \
                self._make_async_copy_of_global_parameters()

        template = self.env.get_template("metadata.json.jinja2")
        return template.render(

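Illustrative sketch (not part of the commit): why the comment above sorts rather than relying on set iteration order — converting each name_import set to a sorted list keeps the serialized JSON stable across runs:

import json
imports = {"typing": {"Any", "Dict"}}
print(json.dumps({k: sorted(v) for k, v in imports.items()}))
# {"typing": ["Any", "Dict"]}
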
@@ -5,13 +5,14 @@
# --------------------------------------------------------------------------
from jinja2 import Environment

from autorest.codegen.serializers.azure_functions.python.import_serializer import FileImportSerializer
from ..models import LROOperation, PagingOperation, CodeModel, OperationGroup
from .azure_functions.python.import_serializer import FileImportSerializer
from ..models import CodeModel, LROOperation, OperationGroup, PagingOperation


class OperationGroupSerializer:
    def __init__(
        self, code_model: CodeModel, env: Environment, operation_group: OperationGroup, async_mode: bool
        self, code_model: CodeModel, env: Environment,
        operation_group: OperationGroup, async_mode: bool
    ) -> None:
        self.code_model = code_model
        self.env = env

@@ -25,18 +26,25 @@ class OperationGroupSerializer:
        def _is_paging(operation):
            return isinstance(operation, PagingOperation)

        operation_group_template = self.env.get_template("operations_container.py.jinja2")
        operation_group_template = self.env.get_template(
            "operations_container.py.jinja2")
        if self.operation_group.is_empty_operation_group:
            operation_group_template = self.env.get_template("operations_container_mixin.py.jinja2")
            operation_group_template = self.env.get_template(
                "operations_container_mixin.py.jinja2")

        return operation_group_template.render(
            code_model=self.code_model,
            operation_group=self.operation_group,
            imports=FileImportSerializer(
                self.operation_group.imports(self.async_mode, bool(self.code_model.schemas)),
                is_python_3_file=self.async_mode
            ),
            async_mode=self.async_mode,
            is_lro=_is_lro,
            is_paging=_is_paging,
            code_model=self.code_model,
            operation_group=self.operation_group,
            imports=FileImportSerializer(
                self.operation_group.imports(
                    self.async_mode,
                    bool(
                        self.code_model.schemas or
                        self.code_model.enums)
                ),
                is_python_3_file=self.async_mode
            ),
            async_mode=self.async_mode,
            is_lro=_is_lro,
            is_paging=_is_paging,
        )

@@ -17,7 +17,7 @@ _LOGGER = logging.getLogger(__name__)

@dispatcher.add_method
def GetPluginNames():
    return ["codegen", "m2r", "namer", "namercsharp", "multiapiscript"]
    return ["codegen", "m2r", "namer"]


@dispatcher.add_method

@@ -34,12 +34,8 @@ def Process(plugin_name: str, session_id: str) -> bool:
        from ..m2r import M2R as PluginToLoad
    elif plugin_name == "namer":
        from ..namer import Namer as PluginToLoad  # type: ignore
    elif plugin_name == "namercsharp":
        from ..namer.csharp import Namer as PluginToLoad  # type: ignore
    elif plugin_name == "codegen":
        from ..codegen import CodeGenerator as PluginToLoad  # type: ignore
    elif plugin_name == "multiapiscript":
        from ..multiapi import MultiApiScriptPlugin as PluginToLoad  # type: ignore
    else:
        _LOGGER.fatal("Unknown plugin name %s", plugin_name)
        raise RuntimeError(f"Unknown plugin name {plugin_name}")

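Illustrative sketch (not part of the commit): the if/elif chain above implements a lazy plugin registry; the same shape with a dict lookup (the plugin classes here are empty stand-ins):

class M2R: ...
class Namer: ...
class CodeGenerator: ...

PLUGINS = {"m2r": M2R, "namer": Namer, "codegen": CodeGenerator}

def load(plugin_name):
    try:
        return PLUGINS[plugin_name]
    except KeyError:
        raise RuntimeError(f"Unknown plugin name {plugin_name}")

print(load("namer").__name__)  # Namer
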
@@ -82,7 +82,7 @@ class NameConverter:
    def _convert_enum_schema(schema: Dict[str, Any]) -> None:
        NameConverter._convert_language_default_pascal_case(schema)
        for choice in schema["choices"]:
            NameConverter._convert_language_default_python_case(choice, pad_string=PadType.Enum)
            NameConverter._convert_language_default_python_case(choice, pad_string=PadType.Enum, all_upper=True)

    @staticmethod
    def _convert_object_schema(schema: Dict[str, Any]) -> None:

@@ -97,9 +97,21 @@ class NameConverter:
        for prop in schema.get("properties", []):
            NameConverter._convert_language_default_python_case(schema=prop, pad_string=PadType.Property)

    @staticmethod
    def _is_schema_an_m4_header_parameter(schema_name: str, schema: Dict[str, Any]) -> bool:
        m4_header_parameters = ["content_type", "accept"]
        return (
            schema_name in m4_header_parameters and
            schema.get('protocol', {}).get('http', {}).get('in', {}) == 'header'
        )

    @staticmethod
    def _convert_language_default_python_case(
        schema: Dict[str, Any], *, pad_string: Optional[PadType] = None, convert_name: bool = False
        schema: Dict[str, Any],
        *,
        pad_string: Optional[PadType] = None,
        convert_name: bool = False,
        all_upper: bool = False
    ) -> None:
        if not schema.get("language") or schema["language"].get("python"):
            return

@@ -107,17 +119,19 @@ class NameConverter:
        schema_name = schema['language']['default']['name']
        schema_python_name = schema['language']['python']['name']

        if not(
            schema_name == 'content_type' and
            schema.get('protocol') and
            schema['protocol'].get('http') and
            schema['protocol']['http']['in'] == "header"
        if not NameConverter._is_schema_an_m4_header_parameter(
            schema_name, schema
        ):
            # only escaping name if it's not a content_type header parameter
            schema_python_name = NameConverter._to_valid_python_name(
                name=schema_name, pad_string=pad_string, convert_name=convert_name
            )
        schema['language']['python']['name'] = schema_python_name.lower()
        # need to add the lower in case certain words, like LRO, are overriden to
        # always return LRO. Without .lower(), for example, begin_lro would be
        # begin_LRO
        schema['language']['python']['name'] = (
            schema_python_name.upper() if all_upper else schema_python_name.lower()
        )

        schema_description = schema["language"]["default"]["description"].strip()
        if pad_string == PadType.Method and not schema_description and not schema["language"]["default"].get("summary"):

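Illustrative sketch (not part of the commit): the new _is_schema_an_m4_header_parameter predicate, exercised as a free function over plain dicts:

def is_m4_header_parameter(schema_name, schema):
    return (
        schema_name in ("content_type", "accept")
        and schema.get("protocol", {}).get("http", {}).get("in", {}) == "header"
    )

print(is_m4_header_parameter("accept", {"protocol": {"http": {"in": "header"}}}))  # True
print(is_m4_header_parameter("accept", {"protocol": {"http": {"in": "query"}}}))   # False
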
@@ -100,6 +100,7 @@ reserved_words = {
    "self",
    # these are kwargs we've reserved for our autorest generated operations
    "content_type",
    "accept",
    "cls",
    "polling",
    "continuation_token",  # for LRO calls