From 3cd8aaf4b0b02cdca84570f50f91d297b6bb696b Mon Sep 17 00:00:00 2001 From: Paolo Tranquilli Date: Mon, 9 Sep 2024 08:59:17 +0200 Subject: [PATCH] Rust: simplify rust doc test annotation --- misc/codegen/generators/rustgen.py | 2 +- misc/codegen/generators/rusttestgen.py | 32 +++++++------------ misc/codegen/lib/rust.py | 8 +++-- misc/codegen/lib/schema.py | 7 ---- misc/codegen/lib/schemadefs.py | 6 +--- .../codegen/templates/rust_test_code.mustache | 2 +- 6 files changed, 20 insertions(+), 37 deletions(-) diff --git a/misc/codegen/generators/rustgen.py b/misc/codegen/generators/rustgen.py index 81d9199b5eb..491997cdfad 100644 --- a/misc/codegen/generators/rustgen.py +++ b/misc/codegen/generators/rustgen.py @@ -36,7 +36,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field: else: table_name = inflection.tableize(table_name) args = dict( - field_name=p.name + ("_" if p.name in rust.keywords else ""), + field_name=rust.avoid_keywords(p.name), base_type=_get_type(p.type), is_optional=p.is_optional, is_repeated=p.is_repeated, diff --git a/misc/codegen/generators/rusttestgen.py b/misc/codegen/generators/rusttestgen.py index 6c956776a9d..1fed5a30a32 100644 --- a/misc/codegen/generators/rusttestgen.py +++ b/misc/codegen/generators/rusttestgen.py @@ -1,5 +1,6 @@ import dataclasses import typing +import inflection from misc.codegen.loaders import schemaloader from . import qlgen @@ -15,19 +16,7 @@ class Param: @dataclasses.dataclass class Function: name: str - generic_params: list[Param] - params: list[Param] - return_type: str - - def __post_init__(self): - if self.generic_params: - self.generic_params[0].first = True - if self.params: - self.params[0].first = True - - @property - def has_generic_params(self) -> bool: - return bool(self.generic_params) + signature: str @dataclasses.dataclass @@ -48,27 +37,28 @@ def generate(opts, renderer): for cls in schema.classes.values(): if (qlgen.should_skip_qltest(cls, schema.classes) or "rust_skip_test_from_doc" in cls.pragmas or - not cls.doc - ): + not cls.doc): continue - fn = cls.rust_doc_test_function - if fn: - generic_params = [Param(k, v) for k, v in fn.params.items() if k[0].isupper() or k[0] == "'"] - params = [Param(k, v) for k, v in fn.params.items() if k[0].islower()] - fn = Function(fn.name, generic_params, params, fn.return_type) code = [] adding_code = False + has_code = False for line in cls.doc: match line, adding_code: case "```", _: adding_code = not adding_code + has_code = True case _, False: code.append(f"// {line}") case _, True: code.append(line) + if not has_code: + continue + test_name = inflection.underscore(cls.name) + signature = cls.rust_doc_test_function + fn = signature and Function(f"test_{test_name}", signature) if fn: indent = 4 * " " code = [indent + l for l in code] test_with = schema.classes[cls.test_with] if cls.test_with else cls - test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{cls.name.lower()}.rs" + test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs" renderer.render(TestCode(code="\n".join(code), function=fn), test) diff --git a/misc/codegen/lib/rust.py b/misc/codegen/lib/rust.py index d28a5f55a24..ac7bf4313d3 100644 --- a/misc/codegen/lib/rust.py +++ b/misc/codegen/lib/rust.py @@ -57,6 +57,11 @@ keywords = { "try", } + +def avoid_keywords(s: str) -> str: + return s + "_" if s in keywords else s + + _field_overrides = [ (re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}), ] @@ -82,8 +87,7 @@ class Field: first: bool = False def __post_init__(self): - if self.field_name in keywords: - self.field_name += "_" + self.field_name = avoid_keywords(self.field_name) @property def type(self) -> str: diff --git a/misc/codegen/lib/schema.py b/misc/codegen/lib/schema.py index a1358f97d50..37f0f53c5f5 100644 --- a/misc/codegen/lib/schema.py +++ b/misc/codegen/lib/schema.py @@ -203,10 +203,3 @@ def split_doc(doc): while trimmed and not trimmed[0]: trimmed.pop(0) return trimmed - - -@dataclass -class FunctionInfo: - name: str - params: dict[str, str] - return_type: str diff --git a/misc/codegen/lib/schemadefs.py b/misc/codegen/lib/schemadefs.py index 143e1ea3bb1..5f424b2bffc 100644 --- a/misc/codegen/lib/schemadefs.py +++ b/misc/codegen/lib/schemadefs.py @@ -159,11 +159,7 @@ _Pragma("cpp_skip") _Pragma("rust_skip_doc_test") -rust.doc_test_function = lambda name, *, lifetimes=(), return_type="()", **kwargs: _annotate( - rust_doc_test_function=_schema.FunctionInfo(name, - params={f"'{lifetime}": "" for lifetime in lifetimes} | kwargs, - return_type=return_type) -) +rust.doc_test_signature = lambda signature: _annotate(rust_doc_test_function=signature) def group(name: str = "") -> _ClassDecorator: diff --git a/misc/codegen/templates/rust_test_code.mustache b/misc/codegen/templates/rust_test_code.mustache index 975d954349e..d38da6a353b 100644 --- a/misc/codegen/templates/rust_test_code.mustache +++ b/misc/codegen/templates/rust_test_code.mustache @@ -1,7 +1,7 @@ // generated by {{generator}} {{#function}} -fn {{name}}{{#has_generic_params}}<{{#generic_params}}{{^first}}, {{/first}}{{name}}{{#type}}: {{.}}{{/type}}{{/generic_params}}>{{/has_generic_params}}({{#params}}{{^first}}, {{/first}}{{name}}: {{type}}{{/params}}) -> {{return_type}} { +fn {{name}}{{signature}} { {{/function}} {{code}} {{#function}}