This commit is contained in:
Paolo Tranquilli 2024-10-15 14:21:11 +02:00
Родитель 248eb7f00c
Коммит bd08bc7923
4 изменённых файлов: 41 добавлений и 10 удалений

Просмотреть файл

@ -1,7 +1,7 @@
from typing import (
Callable as _Callable,
Dict as _Dict,
List as _List,
Iterable as _Iterable,
ClassVar as _ClassVar,
)
from misc.codegen.lib import schema as _schema
@ -279,7 +279,7 @@ _ = _PropertyAnnotation()
drop = object()
def annotate(annotated_cls: type, add_bases: _List[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]:
def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]:
"""
Add or modify schema annotations after a class has been defined previously.
@ -297,7 +297,7 @@ def annotate(annotated_cls: type, add_bases: _List[type] | None = None, replace_
if replace_bases:
annotated_cls.__bases__ = tuple(replace_bases.get(b, b) for b in annotated_cls.__bases__)
if add_bases:
annotated_cls.__bases__ = tuple(annotated_cls.__bases__) + tuple(add_bases)
annotated_cls.__bases__ += tuple(add_bases)
for a in dir(cls):
if a.startswith(_schema.inheritable_pragma_prefix):
setattr(annotated_cls, a, getattr(cls, a))

Просмотреть файл

@ -914,6 +914,36 @@ def test_annotate_replace_bases():
}
def test_annotate_add_bases():
@load
class data:
class Root:
pass
class A(Root):
pass
class B(Root):
pass
class C(Root):
pass
class Derived(A):
pass
@defs.annotate(Derived, add_bases=(B, C))
class _:
pass
assert data.classes == {
"Root": schema.Class("Root", derived={"A", "B", "C"}),
"A": schema.Class("A", bases=["Root"], derived={"Derived"}),
"B": schema.Class("B", bases=["Root"], derived={"Derived"}),
"C": schema.Class("C", bases=["Root"], derived={"Derived"}),
"Derived": schema.Class("Derived", bases=["A", "B", "C"]),
}
def test_annotate_drop_field():
@load
class data:

Просмотреть файл

@ -1741,13 +1741,6 @@ class _:
```
"""
class Callable(AstNode):
"""
A callable. Either a `Function` or a `ClosureExpr`.
"""
param_list: optional["ParamList"] | child
attrs: list["Attr"] | child
@annotate(Function, add_bases=[Callable])
class _:
param_list: drop

Просмотреть файл

@ -63,3 +63,11 @@ class Unimplemented(Unextracted):
The base class for unimplemented nodes. This is used to mark nodes that are not yet extracted.
"""
pass
class Callable(AstNode):
"""
A callable. Either a `Function` or a `ClosureExpr`.
"""
param_list: optional["ParamList"] | child
attrs: list["Attr"] | child