Swift: make C++ code generation language agnostic

This commit is contained in:
Paolo Tranquilli 2023-02-23 12:25:16 +01:00
Родитель feb4e60c4b
Коммит 6d192cdcc1
12 изменённых файлов: 77 добавлений и 55 удалений

Просмотреть файл

@ -5,7 +5,13 @@ load("//misc/bazel:pkg_runfiles.bzl", "pkg_runfiles")
filegroup( filegroup(
name = "schema", name = "schema",
srcs = ["schema.py"] + glob(["*.dbscheme"]), srcs = ["schema.py"],
visibility = ["//swift:__subpackages__"],
)
filegroup(
name = "schema_includes",
srcs = glob(["*.dbscheme"]),
visibility = ["//swift:__subpackages__"], visibility = ["//swift:__subpackages__"],
) )

Просмотреть файл

@ -60,12 +60,19 @@ def _parse_args() -> argparse.Namespace:
p.add_argument("--generated-registry", p.add_argument("--generated-registry",
help="registry file containing information about checked-in generated code"), help="registry file containing information about checked-in generated code"),
] ]
p.add_argument("--script-name",
help="script name to put in header comments of generated files. By default, the path of this "
"script relative to the root directory")
p.add_argument("--trap-library",
help="path to the trap library from an include directory, required if generating C++ trap bindings"),
p.add_argument("--ql-format", action="store_true", default=True, p.add_argument("--ql-format", action="store_true", default=True,
help="use codeql to autoformat QL files (which is the default)") help="use codeql to autoformat QL files (which is the default)")
p.add_argument("--no-ql-format", action="store_false", dest="ql_format", help="do not format QL files") p.add_argument("--no-ql-format", action="store_false", dest="ql_format", help="do not format QL files")
p.add_argument("--codeql-binary", default="codeql", help="command to use for QL formatting (default %(default)s)") p.add_argument("--codeql-binary", default="codeql", help="command to use for QL formatting (default %(default)s)")
p.add_argument("--force", "-f", action="store_true", p.add_argument("--force", "-f", action="store_true",
help="generate all files without skipping unchanged files and overwriting modified ones"), help="generate all files without skipping unchanged files and overwriting modified ones")
p.add_argument("--use-current-directory", action="store_true",
help="do not consider paths as relative to --root-dir or the configuration directory")
opts = p.parse_args() opts = p.parse_args()
if opts.configuration_file is not None: if opts.configuration_file is not None:
with open(opts.configuration_file) as config: with open(opts.configuration_file) as config:
@ -75,16 +82,15 @@ def _parse_args() -> argparse.Namespace:
setattr(opts, flag, getattr(defaults, flag)) setattr(opts, flag, getattr(defaults, flag))
if opts.root_dir is None: if opts.root_dir is None:
opts.root_dir = opts.configuration_file.parent opts.root_dir = opts.configuration_file.parent
if opts.root_dir is None:
p.error("Either --configuration-file or --root-dir must be provided, or a codegen.conf file must be in a "
"containing directory")
if not opts.generate: if not opts.generate:
p.error("Nothing to do, specify --generate") p.error("Nothing to do, specify --generate")
# absolutize all paths relative to --root-dir # absolutize all paths
for arg in path_arguments: for arg in path_arguments:
path = getattr(opts, arg.dest) path = getattr(opts, arg.dest)
if path is not None: if path is not None:
setattr(opts, arg.dest, opts.root_dir / path) setattr(opts, arg.dest, _abspath(path) if opts.use_current_directory else (opts.root_dir / path))
if not opts.script_name:
opts.script_name = paths.exe_file.relative_to(opts.root_dir)
return opts return opts
@ -102,7 +108,7 @@ def run():
log_level = logging.INFO log_level = logging.INFO
logging.basicConfig(format="{levelname} {message}", style='{', level=log_level) logging.basicConfig(format="{levelname} {message}", style='{', level=log_level)
for target in opts.generate: for target in opts.generate:
generate(target, opts, render.Renderer(opts.root_dir)) generate(target, opts, render.Renderer(opts.script_name, opts.root_dir))
if __name__ == "__main__": if __name__ == "__main__":

Просмотреть файл

@ -95,4 +95,5 @@ def generate(opts, renderer):
out = opts.cpp_output out = opts.cpp_output
for dir, classes in processor.get_classes().items(): for dir, classes in processor.get_classes().items():
renderer.render(cpp.ClassList(classes, opts.schema, renderer.render(cpp.ClassList(classes, opts.schema,
include_parent=bool(dir)), out / dir / "TrapClasses") include_parent=bool(dir),
trap_library=opts.trap_library), out / dir / "TrapClasses")

Просмотреть файл

@ -125,8 +125,8 @@ def generate(opts, renderer):
data = schemaloader.load_file(input) data = schemaloader.load_file(input)
dbscheme = Scheme(src=input.relative_to(opts.root_dir), dbscheme = Scheme(src=input.name,
includes=get_includes(data, include_dir=input.parent, root_dir=opts.root_dir), includes=get_includes(data, include_dir=input.parent, root_dir=input.parent),
declarations=get_declarations(data)) declarations=get_declarations(data))
renderer.render(dbscheme, out) renderer.render(dbscheme, out)

Просмотреть файл

@ -72,6 +72,7 @@ def generate(opts, renderer):
assert opts.cpp_output assert opts.cpp_output
tag_graph = {} tag_graph = {}
out = opts.cpp_output out = opts.cpp_output
trap_library = opts.trap_library
traps = {pathlib.Path(): []} traps = {pathlib.Path(): []}
for e in dbschemeloader.iterload(opts.dbscheme): for e in dbschemeloader.iterload(opts.dbscheme):
@ -84,7 +85,8 @@ def generate(opts, renderer):
for dir, entries in traps.items(): for dir, entries in traps.items():
dir = dir or pathlib.Path() dir = dir or pathlib.Path()
renderer.render(cpp.TrapList(entries, opts.dbscheme), out / dir / "TrapEntries") relative_gen_dir = pathlib.Path(*[".." for _ in dir.parents])
renderer.render(cpp.TrapList(entries, opts.dbscheme, trap_library, relative_gen_dir), out / dir / "TrapEntries")
tags = [] tags = []
for tag in toposort_flatten(tag_graph): for tag in toposort_flatten(tag_graph):

Просмотреть файл

@ -1,3 +1,4 @@
import pathlib
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, ClassVar from typing import List, ClassVar
@ -110,6 +111,8 @@ class TrapList:
extensions = ["h", "cpp"] extensions = ["h", "cpp"]
traps: List[Trap] traps: List[Trap]
source: str source: str
trap_library_dir: pathlib.Path
gen_dir: pathlib.Path
@dataclass @dataclass
@ -156,4 +159,5 @@ class ClassList:
classes: List[Class] classes: List[Class]
source: str source: str
trap_library: str
include_parent: bool = False include_parent: bool = False

Просмотреть файл

@ -25,13 +25,10 @@ class Error(Exception):
class Renderer: class Renderer:
""" Template renderer using mustache templates in the `templates` directory """ """ Template renderer using mustache templates in the `templates` directory """
def __init__(self, root_dir: pathlib.Path): def __init__(self, generator: pathlib.Path, root_dir: pathlib.Path):
self._r = pystache.Renderer(search_dirs=str(paths.templates_dir), escape=lambda u: u) self._r = pystache.Renderer(search_dirs=str(paths.templates_dir), escape=lambda u: u)
self._root_dir = root_dir self._root_dir = root_dir
try: self._generator = generator
self._generator = self._get_path(paths.exe_file)
except ValueError:
self._generator = paths.exe_file.name
def _get_path(self, file: pathlib.Path): def _get_path(self, file: pathlib.Path):
return file.relative_to(self._root_dir) return file.relative_to(self._root_dir)
@ -63,7 +60,7 @@ class Renderer:
def manage(self, generated: typing.Iterable[pathlib.Path], stubs: typing.Iterable[pathlib.Path], def manage(self, generated: typing.Iterable[pathlib.Path], stubs: typing.Iterable[pathlib.Path],
registry: pathlib.Path, force: bool = False) -> "RenderManager": registry: pathlib.Path, force: bool = False) -> "RenderManager":
return RenderManager(self._root_dir, generated, stubs, registry, force) return RenderManager(self._generator, self._root_dir, generated, stubs, registry, force)
class RenderManager(Renderer): class RenderManager(Renderer):
@ -88,10 +85,10 @@ class RenderManager(Renderer):
pre: str pre: str
post: typing.Optional[str] = None post: typing.Optional[str] = None
def __init__(self, root_dir: pathlib.Path, generated: typing.Iterable[pathlib.Path], def __init__(self, generator: pathlib.Path, root_dir: pathlib.Path, generated: typing.Iterable[pathlib.Path],
stubs: typing.Iterable[pathlib.Path], stubs: typing.Iterable[pathlib.Path],
registry: pathlib.Path, force: bool = False): registry: pathlib.Path, force: bool = False):
super().__init__(root_dir) super().__init__(generator, root_dir)
self._registry_path = registry self._registry_path = registry
self._force = force self._force = force
self._hashes = {} self._hashes = {}

Просмотреть файл

@ -6,8 +6,8 @@
#include <optional> #include <optional>
#include <vector> #include <vector>
#include "swift/extractor/trap/TrapLabel.h" #include "{{trap_library}}/TrapLabel.h"
#include "swift/extractor/trap/TrapTagTraits.h" #include "{{trap_library}}/TrapTagTraits.h"
#include "./TrapEntries.h" #include "./TrapEntries.h"
{{#include_parent}} {{#include_parent}}
#include "../TrapClasses.h" #include "../TrapClasses.h"

Просмотреть файл

@ -5,9 +5,9 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include "swift/extractor/trap/TrapLabel.h" #include "{{trap_library_dir}}/TrapLabel.h"
#include "swift/extractor/trap/TrapTagTraits.h" #include "{{trap_library_dir}}/TrapTagTraits.h"
#include "swift/extractor/trap/generated/TrapTags.h" #include "{{gen_dir}}/TrapTags.h"
namespace codeql { namespace codeql {
{{#traps}} {{#traps}}

Просмотреть файл

@ -30,7 +30,7 @@ def generate(opts, input, renderer):
def test_empty(generate): def test_empty(generate):
assert generate([]) == dbscheme.Scheme( assert generate([]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[], declarations=[],
) )
@ -43,10 +43,10 @@ def test_includes(input, opts, generate):
write(opts.schema.parent / i, i + " data") write(opts.schema.parent / i, i + " data")
assert generate([]) == dbscheme.Scheme( assert generate([]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[ includes=[
dbscheme.SchemeInclude( dbscheme.SchemeInclude(
src=schema_dir / i, src=pathlib.Path(i),
data=i + " data", data=i + " data",
) for i in includes ) for i in includes
], ],
@ -58,7 +58,7 @@ def test_empty_final_class(generate, dir_param):
assert generate([ assert generate([
schema.Class("Object", group=dir_param.input), schema.Class("Object", group=dir_param.input),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Table( dbscheme.Table(
@ -78,7 +78,7 @@ def test_final_class_with_single_scalar_field(generate, dir_param):
schema.SingleProperty("foo", "bar"), schema.SingleProperty("foo", "bar"),
]), ]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Table( dbscheme.Table(
@ -98,7 +98,7 @@ def test_final_class_with_single_class_field(generate, dir_param):
schema.SingleProperty("foo", "Bar"), schema.SingleProperty("foo", "Bar"),
]), ]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Table( dbscheme.Table(
@ -118,7 +118,7 @@ def test_final_class_with_optional_field(generate, dir_param):
schema.OptionalProperty("foo", "bar"), schema.OptionalProperty("foo", "bar"),
]), ]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Table( dbscheme.Table(
@ -146,7 +146,7 @@ def test_final_class_with_repeated_field(generate, property_cls, dir_param):
property_cls("foo", "bar"), property_cls("foo", "bar"),
]), ]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Table( dbscheme.Table(
@ -174,7 +174,7 @@ def test_final_class_with_predicate_field(generate, dir_param):
schema.PredicateProperty("foo"), schema.PredicateProperty("foo"),
]), ]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Table( dbscheme.Table(
@ -205,7 +205,7 @@ def test_final_class_with_more_fields(generate, dir_param):
schema.PredicateProperty("six"), schema.PredicateProperty("six"),
]), ]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Table( dbscheme.Table(
@ -259,7 +259,7 @@ def test_empty_class_with_derived(generate):
schema.Class(name="Left", bases=["Base"]), schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]), schema.Class(name="Right", bases=["Base"]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Union( dbscheme.Union(
@ -290,7 +290,7 @@ def test_class_with_derived_and_single_property(generate, dir_param):
schema.Class(name="Left", bases=["Base"]), schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]), schema.Class(name="Right", bases=["Base"]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Union( dbscheme.Union(
@ -330,7 +330,7 @@ def test_class_with_derived_and_optional_property(generate, dir_param):
schema.Class(name="Left", bases=["Base"]), schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]), schema.Class(name="Right", bases=["Base"]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Union( dbscheme.Union(
@ -370,7 +370,7 @@ def test_class_with_derived_and_repeated_property(generate, dir_param):
schema.Class(name="Left", bases=["Base"]), schema.Class(name="Left", bases=["Base"]),
schema.Class(name="Right", bases=["Base"]), schema.Class(name="Right", bases=["Base"]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Union( dbscheme.Union(
@ -432,7 +432,7 @@ def test_null_class(generate):
bases=["Base"], bases=["Base"],
), ),
], null="Null") == dbscheme.Scheme( ], null="Null") == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Union( dbscheme.Union(
@ -514,7 +514,7 @@ def test_ipa_classes_ignored(generate):
schema.Class(name="B", ipa=schema.IpaInfo(from_class="A")), schema.Class(name="B", ipa=schema.IpaInfo(from_class="A")),
schema.Class(name="C", ipa=schema.IpaInfo(on_arguments={"x": "A"})), schema.Class(name="C", ipa=schema.IpaInfo(on_arguments={"x": "A"})),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[], declarations=[],
) )
@ -526,7 +526,7 @@ def test_ipa_derived_classes_ignored(generate):
schema.Class(name="B", bases=["A"], ipa=schema.IpaInfo()), schema.Class(name="B", bases=["A"], ipa=schema.IpaInfo()),
schema.Class(name="C", bases=["A"]), schema.Class(name="C", bases=["A"]),
]) == dbscheme.Scheme( ]) == dbscheme.Scheme(
src=schema_file, src=schema_file.name,
includes=[], includes=[],
declarations=[ declarations=[
dbscheme.Union("@a", ["@c"]), dbscheme.Union("@a", ["@c"]),

Просмотреть файл

@ -6,6 +6,8 @@ from swift.codegen.test.utils import *
import hashlib import hashlib
generator = "foo"
@pytest.fixture @pytest.fixture
def pystache_renderer_cls(): def pystache_renderer_cls():
@ -22,7 +24,7 @@ def pystache_renderer(pystache_renderer_cls):
@pytest.fixture @pytest.fixture
def sut(pystache_renderer): def sut(pystache_renderer):
return render.Renderer(paths.root_dir) return render.Renderer(generator, paths.root_dir)
def assert_file(file, text): def assert_file(file, text):
@ -53,7 +55,7 @@ def test_render(pystache_renderer, sut):
assert_file(output, text) assert_file(output, text)
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file.relative_to(paths.root_dir)), mock.call.render_name(data.template, data, generator=generator),
] ]
@ -72,7 +74,7 @@ def test_managed_render(pystache_renderer, sut):
assert_file(registry, f"some/output.txt {hash(text)} {hash(text)}\n") assert_file(registry, f"some/output.txt {hash(text)} {hash(text)}\n")
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file.relative_to(paths.root_dir)), mock.call.render_name(data.template, data, generator=generator),
] ]
@ -90,7 +92,7 @@ def test_managed_render_with_no_registry(pystache_renderer, sut):
assert_file(registry, f"some/output.txt {hash(text)} {hash(text)}\n") assert_file(registry, f"some/output.txt {hash(text)} {hash(text)}\n")
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file.relative_to(paths.root_dir)), mock.call.render_name(data.template, data, generator=generator),
] ]
@ -111,7 +113,7 @@ def test_managed_render_with_post_processing(pystache_renderer, sut):
assert_file(registry, f"some/output.txt {hash(text)} {hash(postprocessed_text)}\n") assert_file(registry, f"some/output.txt {hash(text)} {hash(postprocessed_text)}\n")
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file.relative_to(paths.root_dir)), mock.call.render_name(data.template, data, generator=generator),
] ]
@ -149,7 +151,7 @@ def test_managed_render_with_skipping_of_generated_file(pystache_renderer, sut):
assert_file(registry, f"some/output.txt {hash(some_output)} {hash(some_output)}\n") assert_file(registry, f"some/output.txt {hash(some_output)} {hash(some_output)}\n")
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file.relative_to(paths.root_dir)), mock.call.render_name(data.template, data, generator=generator),
] ]
@ -171,7 +173,7 @@ def test_managed_render_with_skipping_of_stub_file(pystache_renderer, sut):
assert_file(registry, f"some/stub.txt {hash(some_output)} {hash(some_processed_output)}\n") assert_file(registry, f"some/stub.txt {hash(some_output)} {hash(some_processed_output)}\n")
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file.relative_to(paths.root_dir)), mock.call.render_name(data.template, data, generator=generator),
] ]
@ -277,7 +279,7 @@ def test_render_with_extensions(pystache_renderer, sut):
sut.render(data, output) sut.render(data, output)
expected_templates = ["test_template_foo", "test_template_bar", "test_template_baz"] expected_templates = ["test_template_foo", "test_template_bar", "test_template_baz"]
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(t, data, generator=paths.exe_file.relative_to(paths.root_dir)) mock.call.render_name(t, data, generator=generator)
for t in expected_templates for t in expected_templates
] ]
for expected_output, expected_contents in zip(expected_outputs, rendered): for expected_output, expected_contents in zip(expected_outputs, rendered):
@ -301,7 +303,7 @@ def test_managed_render_with_force_not_skipping_generated_file(pystache_renderer
assert_file(registry, f"some/output.txt {hash(some_output)} {hash(some_output)}\n") assert_file(registry, f"some/output.txt {hash(some_output)} {hash(some_output)}\n")
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file.relative_to(paths.root_dir)), mock.call.render_name(data.template, data, generator=generator),
] ]
@ -323,7 +325,7 @@ def test_managed_render_with_force_not_skipping_stub_file(pystache_renderer, sut
assert_file(registry, f"some/stub.txt {hash(some_output)} {hash(some_output)}\n") assert_file(registry, f"some/stub.txt {hash(some_output)} {hash(some_output)}\n")
assert pystache_renderer.mock_calls == [ assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file.relative_to(paths.root_dir)), mock.call.render_name(data.template, data, generator=generator),
] ]

Просмотреть файл

@ -16,10 +16,14 @@ genrule(
cmd = " ".join([ cmd = " ".join([
"$(location //swift/codegen)", "$(location //swift/codegen)",
"--generate=dbscheme,trap,cpp", "--generate=dbscheme,trap,cpp",
"--dbscheme $(RULEDIR)/generated/swift.dbscheme", "--dbscheme=$(RULEDIR)/generated/swift.dbscheme",
"--cpp-output $(RULEDIR)/generated", "--cpp-output=$(RULEDIR)/generated",
"--trap-library=swift/extractor/trap",
"--use-current-dir",
"--schema=$(location //swift:schema)",
"--script-name=codegen/codegen.py",
]), ]),
exec_tools = ["//swift/codegen"], exec_tools = ["//swift/codegen", "//swift:schema"],
) )
filegroup( filegroup(