initial try to create deserializer class
This commit is contained in:
Родитель
3ebf8dcc62
Коммит
01635e1889
|
@ -16,7 +16,7 @@ import re
|
|||
import typing
|
||||
import enum
|
||||
import email.utils
|
||||
from datetime import datetime, date, time, timedelta, timezone
|
||||
import datetime
|
||||
from json import JSONEncoder
|
||||
from typing_extensions import Self
|
||||
import isodate
|
||||
|
@ -38,7 +38,7 @@ TZ_UTC = timezone.utc
|
|||
_T = typing.TypeVar("_T")
|
||||
|
||||
|
||||
def _timedelta_as_isostr(td: timedelta) -> str:
|
||||
def _timedelta_as_isostr(td: datetime.timedelta) -> str:
|
||||
"""Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S'
|
||||
|
||||
Function adapted from the Tin Can Python project: https://github.com/RusticiSoftware/TinCanPython
|
||||
|
@ -170,15 +170,14 @@ _VALID_RFC7231 = re.compile(
|
|||
r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT"
|
||||
)
|
||||
|
||||
|
||||
def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
|
||||
def _deserialize_datetime(attr: typing.Union[str, datetime.datetime]) -> datetime.datetime:
|
||||
"""Deserialize ISO-8601 formatted string into Datetime object.
|
||||
|
||||
:param str attr: response string to be deserialized.
|
||||
:rtype: ~datetime.datetime
|
||||
:returns: The datetime object from that input
|
||||
"""
|
||||
if isinstance(attr, datetime):
|
||||
if isinstance(attr, datetime.datetime):
|
||||
# i'm already deserialized
|
||||
return attr
|
||||
attr = attr.upper()
|
||||
|
@ -203,15 +202,14 @@ def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
|
|||
raise OverflowError("Hit max or min date")
|
||||
return date_obj
|
||||
|
||||
|
||||
def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime:
|
||||
def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime.datetime]) -> datetime.datetime:
|
||||
"""Deserialize RFC7231 formatted string into Datetime object.
|
||||
|
||||
:param str attr: response string to be deserialized.
|
||||
:rtype: ~datetime.datetime
|
||||
:returns: The datetime object from that input
|
||||
"""
|
||||
if isinstance(attr, datetime):
|
||||
if isinstance(attr, datetime.datetime):
|
||||
# i'm already deserialized
|
||||
return attr
|
||||
match = _VALID_RFC7231.match(attr)
|
||||
|
@ -220,91 +218,118 @@ def _deserialize_datetime_rfc7231(attr: typing.Union[str, datetime]) -> datetime
|
|||
|
||||
return email.utils.parsedate_to_datetime(attr)
|
||||
|
||||
class SdkDecoder:
|
||||
@staticmethod
|
||||
def deserialize_datetime(
|
||||
attr: typing.Union[str, datetime.datetime],
|
||||
*,
|
||||
encode: typing.Union[typing.Literal["rfc3339"], typing.Literal["rfc7231"]] = "rfc3339",
|
||||
) -> datetime.datetime:
|
||||
if encode == "rfc3339":
|
||||
return _deserialize_datetime(attr)
|
||||
return _deserialize_datetime_rfc7231(attr)
|
||||
|
||||
@staticmethod
|
||||
def deserialize_date(attr: typing.Union[str, datetime.date]) -> datetime.date:
|
||||
"""Deserialize ISO-8601 formatted string into Date object.
|
||||
:param str attr: response string to be deserialized.
|
||||
:rtype: date
|
||||
:returns: The date object from that input
|
||||
"""
|
||||
# This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
|
||||
if isinstance(attr, datetime.date):
|
||||
return attr
|
||||
return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def deserialize_bytes(
|
||||
attr: typing.Union[str, bytes],
|
||||
*,
|
||||
encode: typing.Union[typing.Literal["base64"], typing.Literal["base64url"]] = "base64",
|
||||
) -> bytes:
|
||||
if isinstance(attr, (bytes, bytearray)):
|
||||
return attr
|
||||
if encode == "base64url":
|
||||
padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore
|
||||
attr = attr + padding # type: ignore
|
||||
encoded = attr.replace("-", "+").replace("_", "/")
|
||||
else:
|
||||
encoded = attr
|
||||
return bytes(base64.b64decode(encoded))
|
||||
|
||||
def _deserialize_datetime_unix_timestamp(attr: typing.Union[float, datetime]) -> datetime:
|
||||
"""Deserialize unix timestamp into Datetime object.
|
||||
|
||||
@staticmethod
|
||||
def deserialize_duration(attr: str) -> datetime.timedelta:
|
||||
if isinstance(attr, datetime.timedelta):
|
||||
return attr
|
||||
return isodate.parse_duration(attr)
|
||||
|
||||
:param str attr: response string to be deserialized.
|
||||
:rtype: ~datetime.datetime
|
||||
:returns: The datetime object from that input
|
||||
"""
|
||||
if isinstance(attr, datetime):
|
||||
# i'm already deserialized
|
||||
@staticmethod
|
||||
def deserialize_time(attr: typing.Union[str, datetime.time]) -> datetime.time:
|
||||
"""Deserialize ISO-8601 formatted string into time object.
|
||||
|
||||
:param str attr: response string to be deserialized.
|
||||
:rtype: datetime.time
|
||||
:returns: The time object from that input
|
||||
"""
|
||||
if isinstance(attr, datetime.time):
|
||||
return attr
|
||||
return isodate.parse_time(attr)
|
||||
|
||||
@staticmethod
|
||||
def deserialize_any(attr: typing.Any) -> typing.Any:
|
||||
return attr
|
||||
return datetime.fromtimestamp(attr, TZ_UTC)
|
||||
|
||||
|
||||
def _deserialize_date(attr: typing.Union[str, date]) -> date:
|
||||
"""Deserialize ISO-8601 formatted string into Date object.
|
||||
:param str attr: response string to be deserialized.
|
||||
:rtype: date
|
||||
:returns: The date object from that input
|
||||
"""
|
||||
# This must NOT use defaultmonth/defaultday. Using None ensure this raises an exception.
|
||||
if isinstance(attr, date):
|
||||
return attr
|
||||
return isodate.parse_date(attr, defaultmonth=None, defaultday=None) # type: ignore
|
||||
|
||||
|
||||
def _deserialize_time(attr: typing.Union[str, time]) -> time:
|
||||
"""Deserialize ISO-8601 formatted string into time object.
|
||||
|
||||
:param str attr: response string to be deserialized.
|
||||
:rtype: datetime.time
|
||||
:returns: The time object from that input
|
||||
"""
|
||||
if isinstance(attr, time):
|
||||
return attr
|
||||
return isodate.parse_time(attr)
|
||||
|
||||
|
||||
def _deserialize_bytes(attr):
|
||||
if isinstance(attr, (bytes, bytearray)):
|
||||
return attr
|
||||
return bytes(base64.b64decode(attr))
|
||||
|
||||
|
||||
def _deserialize_bytes_base64(attr):
|
||||
if isinstance(attr, (bytes, bytearray)):
|
||||
return attr
|
||||
padding = "=" * (3 - (len(attr) + 3) % 4) # type: ignore
|
||||
attr = attr + padding # type: ignore
|
||||
encoded = attr.replace("-", "+").replace("_", "/")
|
||||
return bytes(base64.b64decode(encoded))
|
||||
|
||||
|
||||
def _deserialize_duration(attr):
|
||||
if isinstance(attr, timedelta):
|
||||
return attr
|
||||
return isodate.parse_duration(attr)
|
||||
|
||||
|
||||
def _deserialize_decimal(attr):
|
||||
if isinstance(attr, decimal.Decimal):
|
||||
return attr
|
||||
return decimal.Decimal(str(attr))
|
||||
|
||||
@staticmethod
|
||||
def deserialize_decimal(attr: str) -> decimal.Decimal:
|
||||
if isinstance(attr, decimal.Decimal):
|
||||
return attr
|
||||
return decimal.Decimal(str(attr))
|
||||
|
||||
@staticmethod
|
||||
def deserialize_union(deserializers: typing.List[typing.Callable], attr: typing.Any) -> typing.Any:
|
||||
for deserializer in deserializers:
|
||||
try:
|
||||
return _deserialize(deserializer, attr)
|
||||
except DeserializationError:
|
||||
pass
|
||||
raise DeserializationError()
|
||||
|
||||
@staticmethod
|
||||
def deserialize_dict(
|
||||
value_deserializer: typing.Optional[typing.Callable],
|
||||
attr: typing.Dict[typing.Any, typing.Any],
|
||||
):
|
||||
if attr is None:
|
||||
return attr
|
||||
return {k: _deserialize(value_deserializer, v) for k, v in attr.items()}
|
||||
|
||||
@staticmethod
|
||||
def deserialize_sequence(
|
||||
entry_deserializers: typing.List[typing.Callable],
|
||||
obj: typing.Sequence,
|
||||
):
|
||||
return type(obj)(_deserialize(deserializer, entry) for entry, deserializer in zip(obj, entry_deserializers))
|
||||
|
||||
|
||||
_DESERIALIZE_MAPPING = {
|
||||
datetime: _deserialize_datetime,
|
||||
date: _deserialize_date,
|
||||
time: _deserialize_time,
|
||||
bytes: _deserialize_bytes,
|
||||
bytearray: _deserialize_bytes,
|
||||
timedelta: _deserialize_duration,
|
||||
typing.Any: lambda x: x,
|
||||
decimal.Decimal: _deserialize_decimal,
|
||||
datetime.datetime: SdkDecoder.deserialize_datetime,
|
||||
datetime.date: SdkDecoder.deserialize_date,
|
||||
datetime.time: SdkDecoder.deserialize_time,
|
||||
bytes: SdkDecoder.deserialize_bytes,
|
||||
bytearray: SdkDecoder.deserialize_bytes,
|
||||
datetime.timedelta: SdkDecoder.deserialize_duration,
|
||||
typing.Any: SdkDecoder.deserialize_any,
|
||||
decimal.Decimal: SdkDecoder.deserialize_decimal,
|
||||
}
|
||||
|
||||
_DESERIALIZE_MAPPING_WITHFORMAT = {
|
||||
"rfc3339": _deserialize_datetime,
|
||||
"rfc7231": _deserialize_datetime_rfc7231,
|
||||
"unix-timestamp": _deserialize_datetime_unix_timestamp,
|
||||
"base64": _deserialize_bytes,
|
||||
"base64url": _deserialize_bytes_base64,
|
||||
"rfc3339": SdkDecoder.deserialize_datetime,
|
||||
"rfc7231": functools.partial(SdkDecoder.deserialize_datetime, encode="rfc7231"),
|
||||
"base64": SdkDecoder.deserialize_bytes,
|
||||
"base64url": functools.partial(SdkDecoder.deserialize_bytes, encode="base64url")
|
||||
}
|
||||
|
||||
|
||||
def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None):
|
||||
if rf and rf._format:
|
||||
return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format)
|
||||
|
@ -323,17 +348,9 @@ def _get_type_alias_type(module_name: str, alias_name: str):
|
|||
|
||||
|
||||
def _get_model(module_name: str, model_name: str):
|
||||
models = {
|
||||
k: v
|
||||
for k, v in sys.modules[module_name].__dict__.items()
|
||||
if isinstance(v, type)
|
||||
}
|
||||
models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)}
|
||||
module_end = module_name.rsplit(".", 1)[0]
|
||||
models.update({
|
||||
k: v
|
||||
for k, v in sys.modules[module_end].__dict__.items()
|
||||
if isinstance(v, type)
|
||||
})
|
||||
models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)})
|
||||
if isinstance(model_name, str):
|
||||
model_name = model_name.split(".")[-1]
|
||||
if model_name not in models:
|
||||
|
@ -385,16 +402,13 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
|
|||
return default
|
||||
|
||||
@typing.overload
|
||||
def pop(self, key: str) -> typing.Any:
|
||||
...
|
||||
def pop(self, key: str) -> typing.Any: ...
|
||||
|
||||
@typing.overload
|
||||
def pop(self, key: str, default: _T) -> _T:
|
||||
...
|
||||
def pop(self, key: str, default: _T) -> _T: ...
|
||||
|
||||
@typing.overload
|
||||
def pop(self, key: str, default: typing.Any) -> typing.Any:
|
||||
...
|
||||
def pop(self, key: str, default: typing.Any) -> typing.Any: ...
|
||||
|
||||
def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
|
||||
if default is _UNSET:
|
||||
|
@ -411,12 +425,10 @@ class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=uns
|
|||
self._data.update(*args, **kwargs)
|
||||
|
||||
@typing.overload
|
||||
def setdefault(self, key: str, default: None = None) -> None:
|
||||
...
|
||||
def setdefault(self, key: str, default: None = None) -> None: ...
|
||||
|
||||
@typing.overload
|
||||
def setdefault(self, key: str, default: typing.Any) -> typing.Any:
|
||||
...
|
||||
def setdefault(self, key: str, default: typing.Any) -> typing.Any: ...
|
||||
|
||||
def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
|
||||
if default is _UNSET:
|
||||
|
@ -549,7 +561,9 @@ class Model(_MyMutableMapping):
|
|||
@classmethod
|
||||
def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]:
|
||||
for v in cls.__dict__.values():
|
||||
if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: # pylint: disable=protected-access
|
||||
if (
|
||||
isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators
|
||||
): # pylint: disable=protected-access
|
||||
return v._rest_name # pylint: disable=protected-access
|
||||
return None
|
||||
|
||||
|
@ -559,9 +573,7 @@ class Model(_MyMutableMapping):
|
|||
return cls(data)
|
||||
discriminator = cls._get_discriminator(exist_discriminators)
|
||||
exist_discriminators.append(discriminator)
|
||||
mapped_cls = cls.__mapping__.get(
|
||||
data.get(discriminator), cls
|
||||
) # pyright: ignore # pylint: disable=no-member
|
||||
mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pyright: ignore # pylint: disable=no-member
|
||||
if mapped_cls == cls:
|
||||
return cls(data)
|
||||
return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access
|
||||
|
@ -582,7 +594,9 @@ class Model(_MyMutableMapping):
|
|||
continue
|
||||
is_multipart_file_input = False
|
||||
try:
|
||||
is_multipart_file_input = next(rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k)._is_multipart_file_input
|
||||
is_multipart_file_input = next(
|
||||
rf for rf in self._attr_to_rest_field.values() if rf._rest_name == k
|
||||
)._is_multipart_file_input
|
||||
except StopIteration:
|
||||
pass
|
||||
result[k] = v if is_multipart_file_input else Model._as_dict_value(v, exclude_readonly=exclude_readonly)
|
||||
|
@ -593,67 +607,11 @@ class Model(_MyMutableMapping):
|
|||
if v is None or isinstance(v, _Null):
|
||||
return None
|
||||
if isinstance(v, (list, tuple, set)):
|
||||
return type(v)(
|
||||
Model._as_dict_value(x, exclude_readonly=exclude_readonly)
|
||||
for x in v
|
||||
)
|
||||
return type(v)(Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v)
|
||||
if isinstance(v, dict):
|
||||
return {
|
||||
dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly)
|
||||
for dk, dv in v.items()
|
||||
}
|
||||
return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()}
|
||||
return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v
|
||||
|
||||
def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj):
|
||||
if _is_model(obj):
|
||||
return obj
|
||||
return _deserialize(model_deserializer, obj)
|
||||
|
||||
def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Callable], obj):
|
||||
if obj is None:
|
||||
return obj
|
||||
return _deserialize_with_callable(if_obj_deserializer, obj)
|
||||
|
||||
def _deserialize_with_union(deserializers, obj):
|
||||
for deserializer in deserializers:
|
||||
try:
|
||||
return _deserialize(deserializer, obj)
|
||||
except DeserializationError:
|
||||
pass
|
||||
raise DeserializationError()
|
||||
|
||||
def _deserialize_dict(
|
||||
value_deserializer: typing.Optional[typing.Callable],
|
||||
module: typing.Optional[str],
|
||||
obj: typing.Dict[typing.Any, typing.Any],
|
||||
):
|
||||
if obj is None:
|
||||
return obj
|
||||
return {
|
||||
k: _deserialize(value_deserializer, v, module)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
|
||||
def _deserialize_multiple_sequence(
|
||||
entry_deserializers: typing.List[typing.Optional[typing.Callable]],
|
||||
module: typing.Optional[str],
|
||||
obj,
|
||||
):
|
||||
if obj is None:
|
||||
return obj
|
||||
return type(obj)(
|
||||
_deserialize(deserializer, entry, module)
|
||||
for entry, deserializer in zip(obj, entry_deserializers)
|
||||
)
|
||||
|
||||
def _deserialize_sequence(
|
||||
deserializer: typing.Optional[typing.Callable],
|
||||
module: typing.Optional[str],
|
||||
obj,
|
||||
):
|
||||
if obj is None:
|
||||
return obj
|
||||
return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
|
||||
|
||||
def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912
|
||||
annotation: typing.Any,
|
||||
|
@ -693,27 +651,16 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
|
|||
except AttributeError:
|
||||
pass
|
||||
|
||||
# is it optional?
|
||||
try:
|
||||
if any(a for a in annotation.__args__ if a == type(None)): # pyright: ignore
|
||||
if_obj_deserializer = _get_deserialize_callable_from_annotation(
|
||||
next(a for a in annotation.__args__ if a != type(None)), module, rf # pyright: ignore
|
||||
)
|
||||
|
||||
return functools.partial(_deserialize_with_optional, if_obj_deserializer)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if getattr(annotation, "__origin__", None) is typing.Union:
|
||||
# initial ordering is we make `string` the last deserialization option, because it is often them most generic
|
||||
deserializers = [
|
||||
deserializers: typing.List[typing.Callable] = [
|
||||
_get_deserialize_callable_from_annotation(arg, module, rf)
|
||||
for arg in sorted(
|
||||
annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore
|
||||
)
|
||||
annotation.__args__, key=lambda x: hasattr(x, "__name__") and x.__name__ == "str" # pyright: ignore
|
||||
)
|
||||
]
|
||||
|
||||
return functools.partial(_deserialize_with_union, deserializers)
|
||||
return functools.partial(SdkDecoder.deserialize_union, deserializers)
|
||||
|
||||
try:
|
||||
if annotation._name == "Dict": # pyright: ignore
|
||||
|
@ -721,11 +668,9 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
|
|||
annotation.__args__[1], module, rf # pyright: ignore
|
||||
)
|
||||
|
||||
|
||||
return functools.partial(
|
||||
_deserialize_dict,
|
||||
SdkDecoder.deserialize_dict,
|
||||
value_deserializer,
|
||||
module,
|
||||
)
|
||||
except (AttributeError, IndexError):
|
||||
pass
|
||||
|
@ -733,18 +678,16 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
|
|||
if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore
|
||||
if len(annotation.__args__) > 1: # pyright: ignore
|
||||
|
||||
|
||||
entry_deserializers = [
|
||||
_get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__ # pyright: ignore
|
||||
entry_deserializers: typing.List[typing.Callable] = [
|
||||
_get_deserialize_callable_from_annotation(dt, module, rf)
|
||||
for dt in annotation.__args__ # pyright: ignore
|
||||
]
|
||||
return functools.partial(_deserialize_multiple_sequence, entry_deserializers, module)
|
||||
deserializer = _get_deserialize_callable_from_annotation(
|
||||
return functools.partial(SdkDecoder.deserialize_sequence, entry_deserializers)
|
||||
deserializer: typing.Callable = _get_deserialize_callable_from_annotation(
|
||||
annotation.__args__[0], module, rf # pyright: ignore
|
||||
)
|
||||
|
||||
|
||||
|
||||
return functools.partial(_deserialize_sequence, deserializer, module)
|
||||
return functools.partial(SdkDecoder.deserialize_sequence, [deserializer])
|
||||
except (TypeError, IndexError, AttributeError, SyntaxError):
|
||||
pass
|
||||
|
||||
|
@ -876,7 +819,14 @@ def rest_field(
|
|||
format: typing.Optional[str] = None,
|
||||
is_multipart_file_input: bool = False,
|
||||
) -> typing.Any:
|
||||
return _RestField(name=name, type=type, visibility=visibility, default=default, format=format, is_multipart_file_input=is_multipart_file_input)
|
||||
return _RestField(
|
||||
name=name,
|
||||
type=type,
|
||||
visibility=visibility,
|
||||
default=default,
|
||||
format=format,
|
||||
is_multipart_file_input=is_multipart_file_input,
|
||||
)
|
||||
|
||||
|
||||
def rest_discriminator(
|
||||
|
|
Загрузка…
Ссылка в новой задаче