From 01635e1889ab250e74b5ccc1b6a4d2ba25ac0aca Mon Sep 17 00:00:00 2001 From: iscai-msft Date: Tue, 23 Apr 2024 17:59:38 -0400 Subject: [PATCH] initial try to create deserializer class --- .../codegen/templates/model_base.py.jinja2 | 334 ++++++++---------- 1 file changed, 142 insertions(+), 192 deletions(-) diff --git a/packages/autorest.python/autorest/codegen/templates/model_base.py.jinja2 b/packages/autorest.python/autorest/codegen/templates/model_base.py.jinja2 index d0a073f242..ca92cb23ea 100644 --- a/packages/autorest.python/autorest/codegen/templates/model_base.py.jinja2 +++ b/packages/autorest.python/autorest/codegen/templates/model_base.py.jinja2 @@ -16,7 +16,7 @@ import re import typing import enum import email.utils -from datetime import datetime, date, time, timedelta, timezone +import datetime from json import JSONEncoder from typing_extensions import Self import isodate @@ -38,7 +38,7 @@ TZ_UTC = timezone.utc _T = typing.TypeVar("_T") -def _timedelta_as_isostr(td: timedelta) -> str: +def _timedelta_as_isostr(td: datetime.timedelta) -> str: """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S' Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython @@ -170,15 +170,14 @@ _VALID_RFC7231 = re.compile( r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT" ) - -def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime: +def _deserialize_datetime(attr: typing.Union[str, datetime.datetime]) -> datetime.datetime: """Deserialize ISO-8601 formatted string into Datetime object. :param str attr: response string to be deserialized. :rtype: ~datetime.datetime :returns: The datetime object from that input """ - if isinstance(attr, datetime): + if isinstance(attr, datetime.datetime): # i'm already deserialized return attr attr = attr.upper() @@ -203,15 +202,14 @@ def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime: raise OverflowError("Hit max or min date") return date_obj - -def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime: +def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime.datetime]) -> datetime.datetime: """Deserialize RFC7231 formatted string into Datetime object. :param str attr: response string to be deserialized. :rtype: ~datetime.datetime :returns: The datetime object from that input """ - if isinstance(attr, datetime): + if isinstance(attr, datetime.datetime): # i'm already deserialized return attr match = _VALID_RFC7231.match(attr) @@ -220,91 +218,118 @@ def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime return email.utils.parsedate_to_datetime(attr) +class SdkDecoder: + @staticmethod + def deserialize_datetime( + attr: typing.Union[str, datetime.datetime], + *, + encode: typing.Union[typing.Literal["rfc3339"], typing.Literal["rfc7231"]] = "rfc3339", + ) -> datetime.datetime: + if encode == "rfc3339": + return _deserialize_datetime(attr) + return _deserialize_datetime_rfc7231(attr) + + @staticmethod + def deserialize_date(attr: typing.Union[str, datetime.date]) -> datetime.date: + """Deserialize ISO-8601 formatted string into Date object. + :param str attr: response string to be deserialized. + :rtype: date + :returns: The date object from that input + """ + # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. + if isinstance(attr, datetime.date): + return attr + return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore + + @staticmethod + def deserialize_bytes( + attr: typing.Union[str, bytes], + *, + encode: typing.Union[typing.Literal["base64"], typing.Literal["base64url"]] = "base64", + ) -> bytes: + if isinstance(attr, (bytes, bytearray)): + return attr + if encode == "base64url": + padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore + attr = attr + padding # type: ignore + encoded = attr.replace("-", "+").replace("_", "/") + else: + encoded = attr + return bytes(base64.b64decode(encoded)) -def _deserialize_datetime_unix_timestamp(attr: typing.Union[float, datetime]) -> datetime: - """Deserialize unix timestamp into Datetime object. + + @staticmethod + def deserialize_duration(attr: str) -> datetime.timedelta: + if isinstance(attr, datetime.timedelta): + return attr + return isodate.parse_duration(attr) - :param str attr: response string to be deserialized. - :rtype: ~datetime.datetime - :returns: The datetime object from that input - """ - if isinstance(attr, datetime): - # i'm already deserialized + @staticmethod + def deserialize_time(attr: typing.Union[str, datetime.time]) -> datetime.time: + """Deserialize ISO-8601 formatted string into time object. + + :param str attr: response string to be deserialized. + :rtype: datetime.time + :returns: The time object from that input + """ + if isinstance(attr, datetime.time): + return attr + return isodate.parse_time(attr) + + @staticmethod + def deserialize_any(attr: typing.Any) -> typing.Any: return attr - return datetime.fromtimestamp(attr, TZ_UTC) - - -def _deserialize_date(attr: typing.Union[str, date]) -> date: - """Deserialize ISO-8601 formatted string into Date object. - :param str attr: response string to be deserialized. - :rtype: date - :returns: The date object from that input - """ - # This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception. - if isinstance(attr, date): - return attr - return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore - - -def _deserialize_time(attr: typing.Union[str, time]) -> time: - """Deserialize ISO-8601 formatted string into time object. - - :param str attr: response string to be deserialized. - :rtype: datetime.time - :returns: The time object from that input - """ - if isinstance(attr, time): - return attr - return isodate.parse_time(attr) - - -def _deserialize_bytes(attr): - if isinstance(attr, (bytes, bytearray)): - return attr - return bytes(base64.b64decode(attr)) - - -def _deserialize_bytes_base64(attr): - if isinstance(attr, (bytes, bytearray)): - return attr - padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore - attr = attr + padding # type: ignore - encoded = attr.replace("-", "+").replace("_", "/") - return bytes(base64.b64decode(encoded)) - - -def _deserialize_duration(attr): - if isinstance(attr, timedelta): - return attr - return isodate.parse_duration(attr) - - -def _deserialize_decimal(attr): - if isinstance(attr, decimal.Decimal): - return attr - return decimal.Decimal(str(attr)) + + @staticmethod + def deserialize_decimal(attr: str) -> decimal.Decimal: + if isinstance(attr, decimal.Decimal): + return attr + return decimal.Decimal(str(attr)) + + @staticmethod + def deserialize_union(deserializers: typing.List[typing.Callable], attr: typing.Any) -> typing.Any: + for deserializer in deserializers: + try: + return _deserialize(deserializer, attr) + except DeserializationError: + pass + raise DeserializationError() + + @staticmethod + def deserialize_dict( + value_deserializer: typing.Optional[typing.Callable], + attr: typing.Dict[typing.Any, typing.Any], + ): + if attr is None: + return attr + return {k: _deserialize(value_deserializer, v) for k, v in attr.items()} + + @staticmethod + def deserialize_sequence( + entry_deserializers: typing.List[typing.Callable], + obj: typing.Sequence, + ): + return type(obj)(_deserialize(deserializer, entry) for entry, deserializer in zip(obj, entry_deserializers)) _DESERIALIZE_MAPPING = { - datetime: _deserialize_datetime, - date: _deserialize_date, - time: _deserialize_time, - bytes: _deserialize_bytes, - bytearray: _deserialize_bytes, - timedelta: _deserialize_duration, - typing.Any: lambda x: x, - decimal.Decimal: _deserialize_decimal, + datetime.datetime: SdkDecoder.deserialize_datetime, + datetime.date: SdkDecoder.deserialize_date, + datetime.time: SdkDecoder.deserialize_time, + bytes: SdkDecoder.deserialize_bytes, + bytearray: SdkDecoder.deserialize_bytes, + datetime.timedelta: SdkDecoder.deserialize_duration, + typing.Any: SdkDecoder.deserialize_any, + decimal.Decimal: SdkDecoder.deserialize_decimal, } _DESERIALIZE_MAPPING_WITHFORMAT = { - "rfc3339": _deserialize_datetime, - "rfc7231": _deserialize_datetime_rfc7231, - "unix-timestamp": _deserialize_datetime_unix_timestamp, - "base64": _deserialize_bytes, - "base64url": _deserialize_bytes_base64, + "rfc3339": SdkDecoder.deserialize_datetime, + "rfc7231": functools.partial(SdkDecoder.deserialize_datetime, encode="rfc7231"), + "base64": SdkDecoder.deserialize_bytes, + "base64url": functools.partial(SdkDecoder.deserialize_bytes, encode="base64url") } - def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None): if rf and rf._format: return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format) @@ -323,17 +348,9 @@ def _get_type_alias_type(module_name: str, alias_name: str): def _get_model(module_name: str, model_name: str): - models = { - k: v - for k, v in sys.modules[module_name].__dict__.items() - if isinstance(v, type) - } + models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - models.update({ - k: v - for k, v in sys.modules[module_end].__dict__.items() - if isinstance(v, type) - }) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -385,16 +402,13 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns return default @typing.overload - def pop(self, key: str) -> typing.Any: - ... + def pop(self, key: str) -> typing.Any: ... @typing.overload - def pop(self, key: str, default: _T) -> _T: - ... + def pop(self, key: str, default: _T) -> _T: ... @typing.overload - def pop(self, key: str, default: typing.Any) -> typing.Any: - ... + def pop(self, key: str, default: typing.Any) -> typing.Any: ... def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any: if default is _UNSET: @@ -411,12 +425,10 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns self._data.update(*args, **kwargs) @typing.overload - def setdefault(self, key: str, default: None = None) -> None: - ... + def setdefault(self, key: str, default: None = None) -> None: ... @typing.overload - def setdefault(self, key: str, default: typing.Any) -> typing.Any: - ... + def setdefault(self, key: str, default: typing.Any) -> typing.Any: ... def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any: if default is _UNSET: @@ -549,7 +561,9 @@ class Model(_MyMutableMapping): @classmethod def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @@ -559,9 +573,7 @@ class Model(_MyMutableMapping): return cls(data) discriminator = cls._get_discriminator(exist_discriminators) exist_discriminators.append(discriminator) - mapped_cls = cls.__mapping__.get( - data.get(discriminator), cls - ) # pyright: ignore # pylint: disable=no-member + mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pyright: ignore # pylint: disable=no-member if mapped_cls == cls: return cls(data) return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access @@ -582,7 +594,9 @@ class Model(_MyMutableMapping): continue is_multipart_file_input = False try: - is_multipart_file_input = next(rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k)._is_multipart_file_input + is_multipart_file_input = next( + rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k + )._is_multipart_file_input except StopIteration: pass result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly) @@ -593,67 +607,11 @@ class Model(_MyMutableMapping): if v is None or isinstance(v, _Null): return None if isinstance(v, (list, tuple, set)): - return type(v)( - Model._as_dict_value(x, exclude_readonly=exclude_readonly) - for x in v - ) + return type(v)(Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v) if isinstance(v, dict): - return { - dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) - for dk, dv in v.items() - } + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj): - if _is_model(obj): - return obj - return _deserialize(model_deserializer, obj) - -def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj): - if obj is None: - return obj - return _deserialize_with_callable(if_obj_deserializer, obj) - -def _deserialize_with_union(deserializers, obj): - for deserializer in deserializers: - try: - return _deserialize(deserializer, obj) - except DeserializationError: - pass - raise DeserializationError() - -def _deserialize_dict( - value_deserializer: typing.Optional[typing.Callable], - module: typing.Optional[str], - obj: typing.Dict[typing.Any, typing.Any], -): - if obj is None: - return obj - return { - k: _deserialize(value_deserializer, v, module) - for k, v in obj.items() - } - -def _deserialize_multiple_sequence( - entry_deserializers: typing.List[typing.Optional[typing.Callable]], - module: typing.Optional[str], - obj, -): - if obj is None: - return obj - return type(obj)( - _deserialize(deserializer, entry, module) - for entry, deserializer in zip(obj, entry_deserializers) - ) - -def _deserialize_sequence( - deserializer: typing.Optional[typing.Callable], - module: typing.Optional[str], - obj, -): - if obj is None: - return obj - return type(obj)(_deserialize(deserializer, entry, module) for entry in obj) def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, @@ -693,27 +651,16 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, except AttributeError: pass - # is it optional? - try: - if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore - if_obj_deserializer = _get_deserialize_callable_from_annotation( - next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore - ) - - return functools.partial(_deserialize_with_optional, if_obj_deserializer) - except AttributeError: - pass - if getattr(annotation, "__origin__", None) is typing.Union: # initial ordering is we make `string` the last deserialization option, because it is often them most generic - deserializers = [ + deserializers: typing.List[typing.Callable] = [ _get_deserialize_callable_from_annotation(arg, module, rf) for arg in sorted( - annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore - ) + annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore + ) ] - return functools.partial(_deserialize_with_union, deserializers) + return functools.partial(SdkDecoder.deserialize_union, deserializers) try: if annotation._name == "Dict": # pyright: ignore @@ -721,11 +668,9 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, annotation.__args__[1], module, rf # pyright: ignore ) - return functools.partial( - _deserialize_dict, + SdkDecoder.deserialize_dict, value_deserializer, - module, ) except (AttributeError, IndexError): pass @@ -733,18 +678,16 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore if len(annotation.__args__) > 1: # pyright: ignore - - entry_deserializers = [ - _get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__ # pyright: ignore + entry_deserializers: typing.List[typing.Callable] = [ + _get_deserialize_callable_from_annotation(dt, module, rf) + for dt in annotation.__args__ # pyright: ignore ] - return functools.partial(_deserialize_multiple_sequence, entry_deserializers, module) - deserializer = _get_deserialize_callable_from_annotation( + return functools.partial(SdkDecoder.deserialize_sequence, entry_deserializers) + deserializer: typing.Callable = _get_deserialize_callable_from_annotation( annotation.__args__[0], module, rf # pyright: ignore ) - - - return functools.partial(_deserialize_sequence, deserializer, module) + return functools.partial(SdkDecoder.deserialize_sequence, [deserializer]) except (TypeError, IndexError, AttributeError, SyntaxError): pass @@ -876,7 +819,14 @@ def rest_field( format: typing.Optional[str] = None, is_multipart_file_input: bool = False, ) -> typing.Any: - return _RestField(name=name, type=type, visibility=visibility, default=default, format=format, is_multipart_file_input=is_multipart_file_input) + return _RestField( + name=name, + type=type, + visibility=visibility, + default=default, + format=format, + is_multipart_file_input=is_multipart_file_input, + ) def rest_discriminator(