initial try to create deserializer class

This commit is contained in:
iscai-msft 2024-04-23 17:59:38 -04:00
Родитель 3ebf8dcc62
Коммит 01635e1889
1 изменённых файлов: 142 добавлений и 192 удалений

Просмотреть файл

@ -16,7 +16,7 @@ import re
import typing
import enum
import email.utils
from datetime import datetime, date, time, timedelta, timezone
import datetime
from json import JSONEncoder
from typing_extensions import Self
import isodate
@ -38,7 +38,7 @@ TZ_UTC = timezone.utc
_T = typing.TypeVar("_T")
def _timedelta_as_isostr(td: timedelta) -> str:
def _timedelta_as_isostr(td: datetime.timedelta) -> str:
"""Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython
@ -170,15 +170,14 @@ _VALID_RFC7231 = re.compile(
r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT"
)
def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
def _deserialize_datetime(attr: typing.Union[str, datetime.datetime]) -> datetime.datetime:
"""Deserialize ISO-8601 formatted string into Datetime object.
:param str attr: response string to be deserialized.
:rtype: ~datetime.datetime
:returns: The datetime object from that input
"""
if isinstance(attr, datetime):
if isinstance(attr, datetime.datetime):
# i'm already deserialized
return attr
attr = attr.upper()
@ -203,15 +202,14 @@ def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
raise OverflowError("Hit max or min date")
return date_obj
def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime:
def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime.datetime]) -> datetime.datetime:
"""Deserialize RFC7231 formatted string into Datetime object.
:param str attr: response string to be deserialized.
:rtype: ~datetime.datetime
:returns: The datetime object from that input
"""
if isinstance(attr, datetime):
if isinstance(attr, datetime.datetime):
# i'm already deserialized
return attr
match = _VALID_RFC7231.match(attr)
@ -220,91 +218,118 @@ def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime
return email.utils.parsedate_to_datetime(attr)
class SdkDecoder:
@staticmethod
def deserialize_datetime(
attr: typing.Union[str, datetime.datetime],
*,
encode: typing.Union[typing.Literal["rfc3339"], typing.Literal["rfc7231"]] = "rfc3339",
) -> datetime.datetime:
if encode == "rfc3339":
return _deserialize_datetime(attr)
return _deserialize_datetime_rfc7231(attr)
@staticmethod
def deserialize_date(attr: typing.Union[str, datetime.date]) -> datetime.date:
"""Deserialize ISO-8601 formatted string into Date object.
:param str attr: response string to be deserialized.
:rtype: date
:returns: The date object from that input
"""
# This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
if isinstance(attr, datetime.date):
return attr
return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore
@staticmethod
def deserialize_bytes(
attr: typing.Union[str, bytes],
*,
encode: typing.Union[typing.Literal["base64"], typing.Literal["base64url"]] = "base64",
) -> bytes:
if isinstance(attr, (bytes, bytearray)):
return attr
if encode == "base64url":
padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore
attr = attr + padding # type: ignore
encoded = attr.replace("-", "+").replace("_", "/")
else:
encoded = attr
return bytes(base64.b64decode(encoded))
def _deserialize_datetime_unix_timestamp(attr: typing.Union[float, datetime]) -> datetime:
"""Deserialize unix timestamp into Datetime object.
@staticmethod
def deserialize_duration(attr: str) -> datetime.timedelta:
if isinstance(attr, datetime.timedelta):
return attr
return isodate.parse_duration(attr)
:param str attr: response string to be deserialized.
:rtype: ~datetime.datetime
:returns: The datetime object from that input
"""
if isinstance(attr, datetime):
# i'm already deserialized
@staticmethod
def deserialize_time(attr: typing.Union[str, datetime.time]) -> datetime.time:
"""Deserialize ISO-8601 formatted string into time object.
:param str attr: response string to be deserialized.
:rtype: datetime.time
:returns: The time object from that input
"""
if isinstance(attr, datetime.time):
return attr
return isodate.parse_time(attr)
@staticmethod
def deserialize_any(attr: typing.Any) -> typing.Any:
return attr
return datetime.fromtimestamp(attr, TZ_UTC)
def _deserialize_date(attr: typing.Union[str, date]) -> date:
"""Deserialize ISO-8601 formatted string into Date object.
:param str attr: response string to be deserialized.
:rtype: date
:returns: The date object from that input
"""
# This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
if isinstance(attr, date):
return attr
return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore
def _deserialize_time(attr: typing.Union[str, time]) -> time:
"""Deserialize ISO-8601 formatted string into time object.
:param str attr: response string to be deserialized.
:rtype: datetime.time
:returns: The time object from that input
"""
if isinstance(attr, time):
return attr
return isodate.parse_time(attr)
def _deserialize_bytes(attr):
if isinstance(attr, (bytes, bytearray)):
return attr
return bytes(base64.b64decode(attr))
def _deserialize_bytes_base64(attr):
if isinstance(attr, (bytes, bytearray)):
return attr
padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore
attr = attr + padding # type: ignore
encoded = attr.replace("-", "+").replace("_", "/")
return bytes(base64.b64decode(encoded))
def _deserialize_duration(attr):
if isinstance(attr, timedelta):
return attr
return isodate.parse_duration(attr)
def _deserialize_decimal(attr):
if isinstance(attr, decimal.Decimal):
return attr
return decimal.Decimal(str(attr))
@staticmethod
def deserialize_decimal(attr: str) -> decimal.Decimal:
if isinstance(attr, decimal.Decimal):
return attr
return decimal.Decimal(str(attr))
@staticmethod
def deserialize_union(deserializers: typing.List[typing.Callable], attr: typing.Any) -> typing.Any:
for deserializer in deserializers:
try:
return _deserialize(deserializer, attr)
except DeserializationError:
pass
raise DeserializationError()
@staticmethod
def deserialize_dict(
value_deserializer: typing.Optional[typing.Callable],
attr: typing.Dict[typing.Any, typing.Any],
):
if attr is None:
return attr
return {k: _deserialize(value_deserializer, v) for k, v in attr.items()}
@staticmethod
def deserialize_sequence(
entry_deserializers: typing.List[typing.Callable],
obj: typing.Sequence,
):
return type(obj)(_deserialize(deserializer, entry) for entry, deserializer in zip(obj, entry_deserializers))
_DESERIALIZE_MAPPING = {
datetime: _deserialize_datetime,
date: _deserialize_date,
time: _deserialize_time,
bytes: _deserialize_bytes,
bytearray: _deserialize_bytes,
timedelta: _deserialize_duration,
typing.Any: lambda x: x,
decimal.Decimal: _deserialize_decimal,
datetime.datetime: SdkDecoder.deserialize_datetime,
datetime.date: SdkDecoder.deserialize_date,
datetime.time: SdkDecoder.deserialize_time,
bytes: SdkDecoder.deserialize_bytes,
bytearray: SdkDecoder.deserialize_bytes,
datetime.timedelta: SdkDecoder.deserialize_duration,
typing.Any: SdkDecoder.deserialize_any,
decimal.Decimal: SdkDecoder.deserialize_decimal,
}
_DESERIALIZE_MAPPING_WITHFORMAT = {
"rfc3339": _deserialize_datetime,
"rfc7231": _deserialize_datetime_rfc7231,
"unix-timestamp": _deserialize_datetime_unix_timestamp,
"base64": _deserialize_bytes,
"base64url": _deserialize_bytes_base64,
"rfc3339": SdkDecoder.deserialize_datetime,
"rfc7231": functools.partial(SdkDecoder.deserialize_datetime, encode="rfc7231"),
"base64": SdkDecoder.deserialize_bytes,
"base64url": functools.partial(SdkDecoder.deserialize_bytes, encode="base64url")
}
def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None):
if rf and rf._format:
return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format)
@ -323,17 +348,9 @@ def _get_type_alias_type(module_name: str, alias_name: str):
def _get_model(module_name: str, model_name: str):
models = {
k: v
for k, v in sys.modules[module_name].__dict__.items()
if isinstance(v, type)
}
models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)}
module_end = module_name.rsplit(".", 1)[0]
models.update({
k: v
for k, v in sys.modules[module_end].__dict__.items()
if isinstance(v, type)
})
models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)})
if isinstance(model_name, str):
model_name = model_name.split(".")[-1]
if model_name not in models:
@ -385,16 +402,13 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
return default
@typing.overload
def pop(self, key: str) -> typing.Any:
...
def pop(self, key: str) -> typing.Any: ...
@typing.overload
def pop(self, key: str, default: _T) -> _T:
...
def pop(self, key: str, default: _T) -> _T: ...
@typing.overload
def pop(self, key: str, default: typing.Any) -> typing.Any:
...
def pop(self, key: str, default: typing.Any) -> typing.Any: ...
def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
if default is _UNSET:
@ -411,12 +425,10 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
self._data.update(*args, **kwargs)
@typing.overload
def setdefault(self, key: str, default: None = None) -> None:
...
def setdefault(self, key: str, default: None = None) -> None: ...
@typing.overload
def setdefault(self, key: str, default: typing.Any) -> typing.Any:
...
def setdefault(self, key: str, default: typing.Any) -> typing.Any: ...
def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
if default is _UNSET:
@ -549,7 +561,9 @@ class Model(_MyMutableMapping):
@classmethod
def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]:
for v in cls.__dict__.values():
if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: # pylint: disable=protected-access
if (
isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators
): # pylint: disable=protected-access
return v._rest_name # pylint: disable=protected-access
return None
@ -559,9 +573,7 @@ class Model(_MyMutableMapping):
return cls(data)
discriminator = cls._get_discriminator(exist_discriminators)
exist_discriminators.append(discriminator)
mapped_cls = cls.__mapping__.get(
data.get(discriminator), cls
) # pyright: ignore # pylint: disable=no-member
mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pyright: ignore # pylint: disable=no-member
if mapped_cls == cls:
return cls(data)
return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access
@ -582,7 +594,9 @@ class Model(_MyMutableMapping):
continue
is_multipart_file_input = False
try:
is_multipart_file_input = next(rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k)._is_multipart_file_input
is_multipart_file_input = next(
rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k
)._is_multipart_file_input
except StopIteration:
pass
result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly)
@ -593,67 +607,11 @@ class Model(_MyMutableMapping):
if v is None or isinstance(v, _Null):
return None
if isinstance(v, (list, tuple, set)):
return type(v)(
Model._as_dict_value(x, exclude_readonly=exclude_readonly)
for x in v
)
return type(v)(Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v)
if isinstance(v, dict):
return {
dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly)
for dk, dv in v.items()
}
return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()}
return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v
def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj):
if _is_model(obj):
return obj
return _deserialize(model_deserializer, obj)
def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
if obj is None:
return obj
return _deserialize_with_callable(if_obj_deserializer, obj)
def _deserialize_with_union(deserializers, obj):
for deserializer in deserializers:
try:
return _deserialize(deserializer, obj)
except DeserializationError:
pass
raise DeserializationError()
def _deserialize_dict(
value_deserializer: typing.Optional[typing.Callable],
module: typing.Optional[str],
obj: typing.Dict[typing.Any, typing.Any],
):
if obj is None:
return obj
return {
k: _deserialize(value_deserializer, v, module)
for k, v in obj.items()
}
def _deserialize_multiple_sequence(
entry_deserializers: typing.List[typing.Optional[typing.Callable]],
module: typing.Optional[str],
obj,
):
if obj is None:
return obj
return type(obj)(
_deserialize(deserializer, entry, module)
for entry, deserializer in zip(obj, entry_deserializers)
)
def _deserialize_sequence(
deserializer: typing.Optional[typing.Callable],
module: typing.Optional[str],
obj,
):
if obj is None:
return obj
return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912
annotation: typing.Any,
@ -693,27 +651,16 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
except AttributeError:
pass
# is it optional?
try:
if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
if_obj_deserializer = _get_deserialize_callable_from_annotation(
next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
)
return functools.partial(_deserialize_with_optional, if_obj_deserializer)
except AttributeError:
pass
if getattr(annotation, "__origin__", None) is typing.Union:
# initial ordering is we make `string` the last deserialization option, because it is often them most generic
deserializers = [
deserializers: typing.List[typing.Callable] = [
_get_deserialize_callable_from_annotation(arg, module, rf)
for arg in sorted(
annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore
)
annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore
)
]
return functools.partial(_deserialize_with_union, deserializers)
return functools.partial(SdkDecoder.deserialize_union, deserializers)
try:
if annotation._name == "Dict": # pyright: ignore
@ -721,11 +668,9 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
annotation.__args__[1], module, rf # pyright: ignore
)
return functools.partial(
_deserialize_dict,
SdkDecoder.deserialize_dict,
value_deserializer,
module,
)
except (AttributeError, IndexError):
pass
@ -733,18 +678,16 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore
if len(annotation.__args__) > 1: # pyright: ignore
entry_deserializers = [
_get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__ # pyright: ignore
entry_deserializers: typing.List[typing.Callable] = [
_get_deserialize_callable_from_annotation(dt, module, rf)
for dt in annotation.__args__ # pyright: ignore
]
return functools.partial(_deserialize_multiple_sequence, entry_deserializers, module)
deserializer = _get_deserialize_callable_from_annotation(
return functools.partial(SdkDecoder.deserialize_sequence, entry_deserializers)
deserializer: typing.Callable = _get_deserialize_callable_from_annotation(
annotation.__args__[0], module, rf # pyright: ignore
)
return functools.partial(_deserialize_sequence, deserializer, module)
return functools.partial(SdkDecoder.deserialize_sequence, [deserializer])
except (TypeError, IndexError, AttributeError, SyntaxError):
pass
@ -876,7 +819,14 @@ def rest_field(
format: typing.Optional[str] = None,
is_multipart_file_input: bool = False,
) -> typing.Any:
return _RestField(name=name, type=type, visibility=visibility, default=default, format=format, is_multipart_file_input=is_multipart_file_input)
return _RestField(
name=name,
type=type,
visibility=visibility,
default=default,
format=format,
is_multipart_file_input=is_multipart_file_input,
)
def rest_discriminator(