Rust: simplify rust doc test annotation

This commit is contained in:
Paolo Tranquilli 2024-09-09 08:59:17 +02:00
Родитель 928f3f11f1
Коммит 3cd8aaf4b0
6 изменённых файлов: 20 добавлений и 37 удалений

Просмотреть файл

@ -36,7 +36,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
else:
table_name = inflection.tableize(table_name)
args = dict(
field_name=p.name + ("_" if p.name in rust.keywords else ""),
field_name=rust.avoid_keywords(p.name),
base_type=_get_type(p.type),
is_optional=p.is_optional,
is_repeated=p.is_repeated,

Просмотреть файл

@ -1,5 +1,6 @@
import dataclasses
import typing
import inflection
from misc.codegen.loaders import schemaloader
from . import qlgen
@ -15,19 +16,7 @@ class Param:
@dataclasses.dataclass
class Function:
name: str
generic_params: list[Param]
params: list[Param]
return_type: str
def __post_init__(self):
if self.generic_params:
self.generic_params[0].first = True
if self.params:
self.params[0].first = True
@property
def has_generic_params(self) -> bool:
return bool(self.generic_params)
signature: str
@dataclasses.dataclass
@ -48,27 +37,28 @@ def generate(opts, renderer):
for cls in schema.classes.values():
if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_test_from_doc" in cls.pragmas or
not cls.doc
):
not cls.doc):
continue
fn = cls.rust_doc_test_function
if fn:
generic_params = [Param(k, v) for k, v in fn.params.items() if k[0].isupper() or k[0] == "'"]
params = [Param(k, v) for k, v in fn.params.items() if k[0].islower()]
fn = Function(fn.name, generic_params, params, fn.return_type)
code = []
adding_code = False
has_code = False
for line in cls.doc:
match line, adding_code:
case "```", _:
adding_code = not adding_code
has_code = True
case _, False:
code.append(f"// {line}")
case _, True:
code.append(line)
if not has_code:
continue
test_name = inflection.underscore(cls.name)
signature = cls.rust_doc_test_function
fn = signature and Function(f"test_{test_name}", signature)
if fn:
indent = 4 * " "
code = [indent + l for l in code]
test_with = schema.classes[cls.test_with] if cls.test_with else cls
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{cls.name.lower()}.rs"
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
renderer.render(TestCode(code="\n".join(code), function=fn), test)

Просмотреть файл

@ -57,6 +57,11 @@ keywords = {
"try",
}
def avoid_keywords(s: str) -> str:
return s + "_" if s in keywords else s
_field_overrides = [
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
]
@ -82,8 +87,7 @@ class Field:
first: bool = False
def __post_init__(self):
if self.field_name in keywords:
self.field_name += "_"
self.field_name = avoid_keywords(self.field_name)
@property
def type(self) -> str:

Просмотреть файл

@ -203,10 +203,3 @@ def split_doc(doc):
while trimmed and not trimmed[0]:
trimmed.pop(0)
return trimmed
@dataclass
class FunctionInfo:
name: str
params: dict[str, str]
return_type: str

Просмотреть файл

@ -159,11 +159,7 @@ _Pragma("cpp_skip")
_Pragma("rust_skip_doc_test")
rust.doc_test_function = lambda name, *, lifetimes=(), return_type="()", **kwargs: _annotate(
rust_doc_test_function=_schema.FunctionInfo(name,
params={f"'{lifetime}": "" for lifetime in lifetimes} | kwargs,
return_type=return_type)
)
rust.doc_test_signature = lambda signature: _annotate(rust_doc_test_function=signature)
def group(name: str = "") -> _ClassDecorator:

Просмотреть файл

@ -1,7 +1,7 @@
// generated by {{generator}}
{{#function}}
fn {{name}}{{#has_generic_params}}<{{#generic_params}}{{^first}}, {{/first}}{{name}}{{#type}}: {{.}}{{/type}}{{/generic_params}}>{{/has_generic_params}}({{#params}}{{^first}}, {{/first}}{{name}}: {{type}}{{/params}}) -> {{return_type}} {
fn {{name}}{{signature}} {
{{/function}}
{{code}}
{{#function}}