Swift: add unit tests to code generation

Tests can be run with
```
bazel test //swift/codegen:tests
```

Coverage can be checked installing `pytest-cov` and running
```
pytest --cov=swift/codegen swift/codegen/test
```
This commit is contained in:
Paolo Tranquilli 2022-04-26 18:22:40 +02:00
Родитель 2d05ea3519
Коммит f171ce6341
19 изменённых файлов: 1008 добавлений и 149 удалений

7
.github/workflows/swift-codegen.yml поставляемый
Просмотреть файл

@ -19,9 +19,14 @@ jobs:
cache: 'pip'
- uses: ./.github/actions/fetch-codeql
- uses: bazelbuild/setup-bazelisk@v2
- name: Check code generation
- name: Install dependencies
run: |
pip install -r swift/codegen/requirements.txt
- name: Run unit tests
run: |
bazel test //swift/codegen:tests --test_output=errors
- name: Check that code was generated
run: |
bazel run //swift/codegen
git add swift
git diff --exit-code --stat HEAD

Просмотреть файл

@ -40,3 +40,10 @@ repos:
language: system
entry: bazel run //swift/codegen
pass_filenames: false
- id: swift-codegen-unit-tests
name: Run Swift code generation unit tests
files: ^swift/codegen
language: system
entry: bazel test //swift/codegen:tests
pass_filenames: false

1
conftest.py Normal file
Просмотреть файл

@ -0,0 +1 @@
# this empty file adds the repo root to PYTHON_PATH when running pytest

Двоичные данные
swift/codegen/.coverage Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -1,4 +1,31 @@
py_binary(
name = "codegen",
srcs = glob(["**/*.py"]),
srcs = glob([
"lib/*.py",
"*.py",
]),
)
py_library(
name = "test_utils",
testonly = True,
srcs = ["test/utils.py"],
deps = [":codegen"],
)
[
py_test(
name = src[len("test/"):-len(".py")],
size = "small",
srcs = [src],
deps = [
":codegen",
":test_utils",
],
)
for src in glob(["test/test_*.py"])
]
test_suite(
name = "tests",
)

Просмотреть файл

@ -3,8 +3,8 @@ import pathlib
import inflection
from lib import paths, schema, generator
from lib.dbscheme import *
from swift.codegen.lib import paths, schema, generator
from swift.codegen.lib.dbscheme import *
log = logging.getLogger(__name__)
@ -60,7 +60,7 @@ def cls_to_dbscheme(cls: schema.Class):
def get_declarations(data: schema.Schema):
return [d for cls in data.classes.values() for d in cls_to_dbscheme(cls)]
return [d for cls in data.classes for d in cls_to_dbscheme(cls)]
def get_includes(data: schema.Schema, include_dir: pathlib.Path):
@ -73,11 +73,10 @@ def get_includes(data: schema.Schema, include_dir: pathlib.Path):
def generate(opts, renderer):
input = opts.schema.resolve()
out = opts.dbscheme.resolve()
input = opts.schema
out = opts.dbscheme
with open(input) as src:
data = schema.load(src)
data = schema.load(input)
dbscheme = DbScheme(src=input.relative_to(paths.swift_dir),
includes=get_includes(data, include_dir=input.parent),

Просмотреть файл

@ -10,13 +10,17 @@ from . import paths
def _init_options():
Option("--verbose", "-v", action="store_true")
Option("--schema", tags=["schema"], type=pathlib.Path, default=paths.swift_dir / "codegen/schema.yml")
Option("--dbscheme", tags=["dbscheme"], type=pathlib.Path, default=paths.swift_dir / "ql/lib/swift.dbscheme")
Option("--ql-output", tags=["ql"], type=pathlib.Path, default=paths.swift_dir / "ql/lib/codeql/swift/generated")
Option("--ql-stub-output", tags=["ql"], type=pathlib.Path, default=paths.swift_dir / "ql/lib/codeql/swift/elements")
Option("--schema", tags=["schema"], type=_abspath, default=paths.swift_dir / "codegen/schema.yml")
Option("--dbscheme", tags=["dbscheme"], type=_abspath, default=paths.swift_dir / "ql/lib/swift.dbscheme")
Option("--ql-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/generated")
Option("--ql-stub-output", tags=["ql"], type=_abspath, default=paths.swift_dir / "ql/lib/codeql/swift/elements")
Option("--codeql-binary", tags=["ql"], default="codeql")
def _abspath(x):
return pathlib.Path(x).resolve()
_options = collections.defaultdict(list)

Просмотреть файл

@ -5,13 +5,16 @@ import sys
import os
try:
_workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']) # <- means we are using bazel run
_workspace_dir = pathlib.Path(os.environ['BUILD_WORKSPACE_DIRECTORY']).resolve() # <- means we are using bazel run
swift_dir = _workspace_dir / 'swift'
lib_dir = swift_dir / 'codegen' / 'lib'
except KeyError:
_this_file = pathlib.Path(__file__).resolve()
swift_dir = _this_file.parents[2]
lib_dir = _this_file.parent
lib_dir = swift_dir / 'codegen' / 'lib'
templates_dir = lib_dir / 'templates'
exe_file = pathlib.Path(sys.argv[0]).resolve()
try:
exe_file = pathlib.Path(sys.argv[0]).resolve().relative_to(swift_dir)
except ValueError:
exe_file = pathlib.Path(sys.argv[0]).name

88
swift/codegen/lib/ql.py Normal file
Просмотреть файл

@ -0,0 +1,88 @@
import pathlib
from dataclasses import dataclass, field
from typing import List, ClassVar
import inflection
@dataclass
class QlParam:
param: str
type: str = None
first: bool = False
@dataclass
class QlProperty:
singular: str
type: str
tablename: str
tableparams: List[QlParam]
plural: str = None
params: List[QlParam] = field(default_factory=list)
first: bool = False
local_var: str = "x"
def __post_init__(self):
if self.params:
self.params[0].first = True
while self.local_var in (p.param for p in self.params):
self.local_var += "_"
assert self.tableparams
if self.type_is_class:
self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams]
self.tableparams = [QlParam(x) for x in self.tableparams]
self.tableparams[0].first = True
@property
def indefinite_article(self):
if self.plural:
return "An" if self.singular[0] in "AEIO" else "A"
@property
def type_is_class(self):
return self.type[0].isupper()
@dataclass
class QlClass:
template: ClassVar = 'ql_class'
name: str
bases: List[str] = field(default_factory=list)
final: bool = False
properties: List[QlProperty] = field(default_factory=list)
dir: pathlib.Path = pathlib.Path()
imports: List[str] = field(default_factory=list)
def __post_init__(self):
self.bases = sorted(self.bases)
if self.properties:
self.properties[0].first = True
@property
def db_id(self):
return "@" + inflection.underscore(self.name)
@property
def root(self):
return not self.bases
@property
def path(self):
return self.dir / self.name
@dataclass
class QlStub:
template: ClassVar = 'ql_stub'
name: str
base_import: str
@dataclass
class QlImportList:
template: ClassVar = 'ql_imports'
imports: List[str] = field(default_factory=list)

Просмотреть файл

@ -19,8 +19,7 @@ class Renderer:
""" Template renderer using mustache templates in the `templates` directory """
def __init__(self):
self.r = pystache.Renderer(search_dirs=str(paths.lib_dir / "templates"), escape=lambda u: u)
self.generator = paths.exe_file.relative_to(paths.swift_dir)
self._r = pystache.Renderer(search_dirs=str(paths.lib_dir / "templates"), escape=lambda u: u)
self.written = set()
def render(self, data, output: pathlib.Path):
@ -32,7 +31,7 @@ class Renderer:
"""
mnemonic = type(data).__name__
output.parent.mkdir(parents=True, exist_ok=True)
data = self.r.render_name(data.template, data, generator=self.generator)
data = self._r.render_name(data.template, data, generator=paths.exe_file)
with open(output, "w") as out:
out.write(data)
log.debug(f"generated {mnemonic} {output.name}")
@ -41,6 +40,5 @@ class Renderer:
def cleanup(self, existing):
""" Remove files in `existing` for which no `render` has been called """
for f in existing - self.written:
if f.is_file():
f.unlink()
log.info(f"removed {f.name}")
f.unlink(missing_ok=True)
log.info(f"removed {f.name}")

Просмотреть файл

@ -3,7 +3,6 @@
import pathlib
import re
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Set, Dict, ClassVar
import yaml
@ -47,7 +46,7 @@ class Class:
@dataclass
class Schema:
classes: Dict[str, Class]
classes: List[Class]
includes: Set[str] = field(default_factory=set)
@ -65,6 +64,7 @@ def _parse_property(name, type):
class _DirSelector:
""" Default output subdirectory selector for generated QL files, based on the `_directories` global field"""
def __init__(self, dir_to_patterns):
self.selector = [(re.compile(p), pathlib.Path(d)) for d, p in dir_to_patterns]
self.selector.append((re.compile(""), pathlib.Path()))
@ -73,19 +73,19 @@ class _DirSelector:
return next(d for p, d in self.selector if p.search(name))
def load(file):
""" Parse the schema from `file` """
data = yaml.load(file, Loader=yaml.SafeLoader)
def load(path):
""" Parse the schema from the file at `path` """
with open(path) as input:
data = yaml.load(input, Loader=yaml.SafeLoader)
grouper = _DirSelector(data.get("_directories", {}).items())
ret = Schema(classes={cls: Class(cls, dir=grouper.get(cls)) for cls in data if not cls.startswith("_")},
includes=set(data.get("_includes", [])))
assert root_class_name not in ret.classes
ret.classes[root_class_name] = Class(root_class_name)
classes = {root_class_name: Class(root_class_name)}
assert root_class_name not in data
classes.update((cls, Class(cls, dir=grouper.get(cls))) for cls in data if not cls.startswith("_"))
for name, info in data.items():
if name.startswith("_"):
continue
assert name[0].isupper()
cls = ret.classes[name]
cls = classes[name]
for k, v in info.items():
if not k.startswith("_"):
cls.properties.append(_parse_property(k, v))
@ -94,11 +94,11 @@ def load(file):
v = [v]
for base in v:
cls.bases.add(base)
ret.classes[base].derived.add(name)
classes[base].derived.add(name)
elif k == "_dir":
cls.dir = pathlib.Path(v)
if not cls.bases:
cls.bases.add(root_class_name)
ret.classes[root_class_name].derived.add(name)
classes[root_class_name].derived.add(name)
return ret
return Schema(classes=list(classes.values()), includes=set(data.get("_includes", [])))

Просмотреть файл

@ -1,129 +1,43 @@
#!/usr/bin/env python3
import logging
import pathlib
import subprocess
from dataclasses import dataclass, field
from typing import List, ClassVar
import inflection
from lib import schema, paths, generator
from swift.codegen.lib import schema, paths, generator, ql
log = logging.getLogger(__name__)
@dataclass
class QlParam:
param: str
type: str = None
first: bool = False
@dataclass
class QlProperty:
singular: str
type: str
tablename: str
tableparams: List[QlParam]
plural: str = None
params: List[QlParam] = field(default_factory=list)
first: bool = False
local_var: str = "x"
def __post_init__(self):
if self.params:
self.params[0].first = True
while self.local_var in (p.param for p in self.params):
self.local_var += "_"
assert self.tableparams
if self.type_is_class:
self.tableparams = [x if x != "result" else self.local_var for x in self.tableparams]
self.tableparams = [QlParam(x) for x in self.tableparams]
self.tableparams[0].first = True
@property
def indefinite_article(self):
if self.plural:
return "An" if self.singular[0] in "AEIO" else "A"
@property
def type_is_class(self):
return self.type[0].isupper()
@dataclass
class QlClass:
template: ClassVar = 'ql_class'
name: str
bases: List[str]
final: bool
properties: List[QlProperty]
dir: pathlib.Path
imports: List[str] = field(default_factory=list)
def __post_init__(self):
self.bases = sorted(self.bases)
if self.properties:
self.properties[0].first = True
@property
def db_id(self):
return "@" + inflection.underscore(self.name)
@property
def root(self):
return not self.bases
@property
def path(self):
return self.dir / self.name
@dataclass
class QlStub:
template: ClassVar = 'ql_stub'
name: str
base_import: str
@dataclass
class QlImportList:
template: ClassVar = 'ql_imports'
imports: List[str] = field(default_factory=list)
def get_ql_property(cls: schema.Class, prop: schema.Property):
if prop.is_single:
return QlProperty(
return ql.QlProperty(
singular=inflection.camelize(prop.name),
type=prop.type,
tablename=inflection.tableize(cls.name),
tableparams=["this"] + ["result" if p is prop else "_" for p in cls.properties if p.is_single],
)
elif prop.is_optional:
return QlProperty(
return ql.QlProperty(
singular=inflection.camelize(prop.name),
type=prop.type,
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
tableparams=["this", "result"],
)
elif prop.is_repeated:
return QlProperty(
return ql.QlProperty(
singular=inflection.singularize(inflection.camelize(prop.name)),
plural=inflection.pluralize(inflection.camelize(prop.name)),
type=prop.type,
tablename=inflection.tableize(f"{cls.name}_{prop.name}"),
tableparams=["this", "index", "result"],
params=[QlParam("index", type="int")],
params=[ql.QlParam("index", type="int")],
)
def get_ql_class(cls: schema.Class):
return QlClass(
return ql.QlClass(
name=cls.name,
bases=cls.bases,
final=not cls.derived,
@ -137,7 +51,7 @@ def get_import(file):
return str(stem).replace("/", ".")
def get_types_used_by(cls: QlClass):
def get_types_used_by(cls: ql.QlClass):
for b in cls.bases:
yield b
for p in cls.properties:
@ -146,7 +60,7 @@ def get_types_used_by(cls: QlClass):
yield param.type
def get_classes_used_by(cls: QlClass):
def get_classes_used_by(cls: ql.QlClass):
return sorted(set(t for t in get_types_used_by(cls) if t[0].isupper()))
@ -164,34 +78,32 @@ def format(codeql, files):
def generate(opts, renderer):
input = opts.schema.resolve()
out = opts.ql_output.resolve()
stub_out = opts.ql_stub_output.resolve()
input = opts.schema
out = opts.ql_output
stub_out = opts.ql_stub_output
existing = {q for q in out.rglob("*.qll")}
existing |= {q for q in stub_out.rglob("*.qll") if is_generated(q)}
with open(input) as src:
data = schema.load(src)
data = schema.load(input)
classes = [get_ql_class(cls) for cls in data.classes.values()]
classes = [get_ql_class(cls) for cls in data.classes]
imports = {}
for c in classes:
imports[c.name] = get_import(stub_out / c.path)
for c in classes:
assert not c.final or c.bases, c.name
qll = (out / c.path).with_suffix(".qll")
c.imports = [imports[t] for t in get_classes_used_by(c)]
renderer.render(c, qll)
stub_file = (stub_out / c.path).with_suffix(".qll")
if not stub_file.is_file() or is_generated(stub_file):
stub = QlStub(name=c.name, base_import=get_import(qll))
stub = ql.QlStub(name=c.name, base_import=get_import(qll))
renderer.render(stub, stub_file)
# for example path/to/syntax/generated -> path/to/syntax.qll
include_file = stub_out.with_suffix(".qll")
all_imports = QlImportList(v for _, v in sorted(imports.items()))
all_imports = ql.QlImportList([v for _, v in sorted(imports.items())])
renderer.render(all_imports, include_file)
renderer.cleanup(existing)

Просмотреть файл

@ -1,3 +1,4 @@
pystache
pyyaml
inflection
pytest

Просмотреть файл

@ -0,0 +1,328 @@
import pathlib
import sys
from swift.codegen import dbschemegen
from swift.codegen.lib import dbscheme, paths
from swift.codegen.test.utils import *
def generate(opts, renderer):
(out, data), = run_generation(dbschemegen.generate, opts, renderer).items()
assert out is opts.dbscheme
return data
def test_empty(opts, input, renderer):
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[],
)
def test_includes(opts, input, renderer):
includes = ["foo", "bar"]
input.includes = includes
for i in includes:
write(opts.schema.parent / i, i + " data")
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[
dbscheme.DbSchemeInclude(
src=schema_dir / i,
data=i + " data",
) for i in includes
],
declarations=[],
)
def test_empty_final_class(opts, input, renderer):
input.classes = [
schema.Class("Object"),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbTable(
name="objects",
columns=[
dbscheme.DbColumn('id', '@object', binding=True),
]
)
],
)
def test_final_class_with_single_scalar_field(opts, input, renderer):
input.classes = [
schema.Class("Object", properties=[
schema.SingleProperty("foo", "bar"),
]),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbTable(
name="objects",
columns=[
dbscheme.DbColumn('id', '@object', binding=True),
dbscheme.DbColumn('foo', 'bar'),
]
)
],
)
def test_final_class_with_single_class_field(opts, input, renderer):
input.classes = [
schema.Class("Object", properties=[
schema.SingleProperty("foo", "Bar"),
]),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbTable(
name="objects",
columns=[
dbscheme.DbColumn('id', '@object', binding=True),
dbscheme.DbColumn('foo', '@bar'),
]
)
],
)
def test_final_class_with_optional_field(opts, input, renderer):
input.classes = [
schema.Class("Object", properties=[
schema.OptionalProperty("foo", "bar"),
]),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbTable(
name="objects",
columns=[
dbscheme.DbColumn('id', '@object', binding=True),
]
),
dbscheme.DbTable(
name="object_foos",
keyset=dbscheme.DbKeySet(["id"]),
columns=[
dbscheme.DbColumn('id', '@object'),
dbscheme.DbColumn('foo', 'bar'),
]
),
],
)
def test_final_class_with_repeated_field(opts, input, renderer):
input.classes = [
schema.Class("Object", properties=[
schema.RepeatedProperty("foo", "bar"),
]),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbTable(
name="objects",
columns=[
dbscheme.DbColumn('id', '@object', binding=True),
]
),
dbscheme.DbTable(
name="object_foos",
keyset=dbscheme.DbKeySet(["id", "index"]),
columns=[
dbscheme.DbColumn('id', '@object'),
dbscheme.DbColumn('index', 'int'),
dbscheme.DbColumn('foo', 'bar'),
]
),
],
)
def test_final_class_with_more_fields(opts, input, renderer):
input.classes = [
schema.Class("Object", properties=[
schema.SingleProperty("one", "x"),
schema.SingleProperty("two", "y"),
schema.OptionalProperty("three", "z"),
schema.RepeatedProperty("four", "w"),
]),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbTable(
name="objects",
columns=[
dbscheme.DbColumn('id', '@object', binding=True),
dbscheme.DbColumn('one', 'x'),
dbscheme.DbColumn('two', 'y'),
]
),
dbscheme.DbTable(
name="object_threes",
keyset=dbscheme.DbKeySet(["id"]),
columns=[
dbscheme.DbColumn('id', '@object'),
dbscheme.DbColumn('three', 'z'),
]
),
dbscheme.DbTable(
name="object_fours",
keyset=dbscheme.DbKeySet(["id", "index"]),
columns=[
dbscheme.DbColumn('id', '@object'),
dbscheme.DbColumn('index', 'int'),
dbscheme.DbColumn('four', 'w'),
]
),
],
)
def test_empty_class_with_derived(opts, input, renderer):
input.classes = [
schema.Class(
name="Base",
derived={"Left", "Right"}),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbUnion(
lhs="@base",
rhs=["@left", "@right"],
),
],
)
def test_class_with_derived_and_single_property(opts, input, renderer):
input.classes = [
schema.Class(
name="Base",
derived={"Left", "Right"},
properties=[
schema.SingleProperty("single", "Prop"),
]),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbUnion(
lhs="@base",
rhs=["@left", "@right"],
),
dbscheme.DbTable(
name="bases",
keyset=dbscheme.DbKeySet(["id"]),
columns=[
dbscheme.DbColumn('id', '@base'),
dbscheme.DbColumn('single', '@prop'),
]
)
],
)
def test_class_with_derived_and_optional_property(opts, input, renderer):
input.classes = [
schema.Class(
name="Base",
derived={"Left", "Right"},
properties=[
schema.OptionalProperty("opt", "Prop"),
]),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbUnion(
lhs="@base",
rhs=["@left", "@right"],
),
dbscheme.DbTable(
name="base_opts",
keyset=dbscheme.DbKeySet(["id"]),
columns=[
dbscheme.DbColumn('id', '@base'),
dbscheme.DbColumn('opt', '@prop'),
]
)
],
)
def test_class_with_derived_and_repeated_property(opts, input, renderer):
input.classes = [
schema.Class(
name="Base",
derived={"Left", "Right"},
properties=[
schema.RepeatedProperty("rep", "Prop"),
]),
]
assert generate(opts, renderer) == dbscheme.DbScheme(
src=schema_file,
includes=[],
declarations=[
dbscheme.DbUnion(
lhs="@base",
rhs=["@left", "@right"],
),
dbscheme.DbTable(
name="base_reps",
keyset=dbscheme.DbKeySet(["id", "index"]),
columns=[
dbscheme.DbColumn('id', '@base'),
dbscheme.DbColumn('index', 'int'),
dbscheme.DbColumn('rep', '@prop'),
]
)
],
)
def test_dbcolumn_name():
assert dbscheme.DbColumn("foo", "some_type").name == "foo"
@pytest.mark.parametrize("keyword", dbscheme.dbscheme_keywords)
def test_dbcolumn_keyword_name(keyword):
assert dbscheme.DbColumn(keyword, "some_type").name == keyword + "_"
@pytest.mark.parametrize("type,binding,lhstype,rhstype", [
("builtin_type", False, "builtin_type", "builtin_type ref"),
("builtin_type", True, "builtin_type", "builtin_type ref"),
("@at_type", False, "int", "@at_type ref"),
("@at_type", True, "unique int", "@at_type"),
])
def test_dbcolumn_types(type, binding, lhstype, rhstype):
col = dbscheme.DbColumn("foo", type, binding)
assert col.lhstype == lhstype
assert col.rhstype == rhstype
if __name__ == '__main__':
sys.exit(pytest.main())

Просмотреть файл

@ -0,0 +1,199 @@
import subprocess
import sys
import mock
from swift.codegen import qlgen
from swift.codegen.lib import ql, paths
from swift.codegen.test.utils import *
@pytest.fixture(autouse=True)
def run_mock():
with mock.patch("subprocess.run") as ret:
yield ret
stub_path = lambda: paths.swift_dir / "ql/lib/stub/path"
ql_output_path = lambda: paths.swift_dir / "ql/lib/other/path"
import_file = lambda: stub_path().with_suffix(".qll")
stub_import_prefix = "stub.path."
gen_import_prefix = "other.path."
index_param = ql.QlParam("index", "int")
def generate(opts, renderer, written=None):
opts.ql_stub_output = stub_path()
opts.ql_output = ql_output_path()
renderer.written = written or []
return run_generation(qlgen.generate, opts, renderer)
def test_empty(opts, input, renderer):
assert generate(opts, renderer) == {
import_file(): ql.QlImportList()
}
def test_one_empty_class(opts, input, renderer):
input.classes = [
schema.Class("A")
]
assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "A"]),
stub_path() / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "A"),
ql_output_path() / "A.qll": ql.QlClass(name="A", final=True),
}
def test_hierarchy(opts, input, renderer):
input.classes = [
schema.Class("D", bases={"B", "C"}),
schema.Class("C", bases={"A"}, derived={"D"}),
schema.Class("B", bases={"A"}, derived={"D"}),
schema.Class("A", derived={"B", "C"}),
]
assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + cls for cls in "ABCD"]),
stub_path() / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "A"),
stub_path() / "B.qll": ql.QlStub(name="B", base_import=gen_import_prefix + "B"),
stub_path() / "C.qll": ql.QlStub(name="C", base_import=gen_import_prefix + "C"),
stub_path() / "D.qll": ql.QlStub(name="D", base_import=gen_import_prefix + "D"),
ql_output_path() / "A.qll": ql.QlClass(name="A"),
ql_output_path() / "B.qll": ql.QlClass(name="B", bases=["A"], imports=[stub_import_prefix + "A"]),
ql_output_path() / "C.qll": ql.QlClass(name="C", bases=["A"], imports=[stub_import_prefix + "A"]),
ql_output_path() / "D.qll": ql.QlClass(name="D", final=True, bases=["B", "C"],
imports=[stub_import_prefix + cls for cls in "BC"]),
}
def test_single_property(opts, input, renderer):
input.classes = [
schema.Class("MyObject", properties=[schema.SingleProperty("foo", "bar")]),
]
assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[
ql.QlProperty(singular="Foo", type="bar", tablename="my_objects", tableparams=["this", "result"]),
])
}
def test_single_properties(opts, input, renderer):
input.classes = [
schema.Class("MyObject", properties=[
schema.SingleProperty("one", "x"),
schema.SingleProperty("two", "y"),
schema.SingleProperty("three", "z"),
]),
]
assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[
ql.QlProperty(singular="One", type="x", tablename="my_objects", tableparams=["this", "result", "_", "_"]),
ql.QlProperty(singular="Two", type="y", tablename="my_objects", tableparams=["this", "_", "result", "_"]),
ql.QlProperty(singular="Three", type="z", tablename="my_objects", tableparams=["this", "_", "_", "result"]),
])
}
def test_optional_property(opts, input, renderer):
input.classes = [
schema.Class("MyObject", properties=[schema.OptionalProperty("foo", "bar")]),
]
assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[
ql.QlProperty(singular="Foo", type="bar", tablename="my_object_foos", tableparams=["this", "result"]),
])
}
def test_repeated_property(opts, input, renderer):
input.classes = [
schema.Class("MyObject", properties=[schema.RepeatedProperty("foo", "bar")]),
]
assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + "MyObject"]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
ql_output_path() / "MyObject.qll": ql.QlClass(name="MyObject", final=True, properties=[
ql.QlProperty(singular="Foo", plural="Foos", type="bar", tablename="my_object_foos", params=[index_param],
tableparams=["this", "index", "result"]),
])
}
def test_single_class_property(opts, input, renderer):
input.classes = [
schema.Class("MyObject", properties=[schema.SingleProperty("foo", "Bar")]),
schema.Class("Bar"),
]
assert generate(opts, renderer) == {
import_file(): ql.QlImportList([stub_import_prefix + cls for cls in ("Bar", "MyObject")]),
stub_path() / "MyObject.qll": ql.QlStub(name="MyObject", base_import=gen_import_prefix + "MyObject"),
stub_path() / "Bar.qll": ql.QlStub(name="Bar", base_import=gen_import_prefix + "Bar"),
ql_output_path() / "MyObject.qll": ql.QlClass(
name="MyObject", final=True, imports=[stub_import_prefix + "Bar"], properties=[
ql.QlProperty(singular="Foo", type="Bar", tablename="my_objects", tableparams=["this", "result"]),
],
),
ql_output_path() / "Bar.qll": ql.QlClass(name="Bar", final=True)
}
def test_class_dir(opts, input, renderer):
dir = pathlib.Path("another/rel/path")
input.classes = [
schema.Class("A", derived={"B"}, dir=dir),
schema.Class("B", bases={"A"}),
]
assert generate(opts, renderer) == {
import_file(): ql.QlImportList([
stub_import_prefix + "another.rel.path.A",
stub_import_prefix + "B",
]),
stub_path() / dir / "A.qll": ql.QlStub(name="A", base_import=gen_import_prefix + "another.rel.path.A"),
stub_path() / "B.qll": ql.QlStub(name="B", base_import=gen_import_prefix + "B"),
ql_output_path() / dir / "A.qll": ql.QlClass(name="A", dir=dir),
ql_output_path() / "B.qll": ql.QlClass(name="B", final=True, bases=["A"],
imports=[stub_import_prefix + "another.rel.path.A"])
}
def test_format(opts, input, renderer, run_mock):
opts.codeql_binary = "my_fake_codeql"
run_mock.return_value.stderr = "some\nlines\n"
generate(opts, renderer, written=["foo", "bar"])
assert run_mock.mock_calls == [
mock.call(["my_fake_codeql", "query", "format", "--in-place", "--", "foo", "bar"],
check=True, stderr=subprocess.PIPE, text=True),
]
def test_empty_cleanup(opts, input, renderer):
generate(opts, renderer)
assert renderer.mock_calls[-1] == mock.call.cleanup(set())
def test_empty_cleanup(opts, input, renderer, tmp_path):
opts.ql_output = tmp_path / "gen"
opts.ql_stub_output = tmp_path / "stub"
renderer.written = []
ql_a = opts.ql_output / "A.qll"
ql_b = opts.ql_output / "B.qll"
stub_a = opts.ql_stub_output / "A.qll"
stub_b = opts.ql_stub_output / "B.qll"
write(ql_a)
write(ql_b)
write(stub_a, "// generated\nfoo\n")
write(stub_b, "bar\n")
run_generation(qlgen.generate, opts, renderer)
assert renderer.mock_calls[-1] == mock.call.cleanup({ql_a, ql_b, stub_a})
if __name__ == '__main__':
sys.exit(pytest.main())

Просмотреть файл

@ -0,0 +1,79 @@
import sys
from unittest import mock
import pytest
from swift.codegen.lib import paths
from swift.codegen.lib import render
@pytest.fixture
def pystache_renderer_cls():
with mock.patch("pystache.Renderer") as ret:
yield ret
@pytest.fixture
def pystache_renderer(pystache_renderer_cls):
ret = mock.Mock()
pystache_renderer_cls.side_effect = (ret,)
return ret
@pytest.fixture
def sut(pystache_renderer):
return render.Renderer()
def test_constructor(pystache_renderer_cls, sut):
pystache_init, = pystache_renderer_cls.mock_calls
assert set(pystache_init.kwargs) == {'search_dirs', 'escape'}
assert pystache_init.kwargs['search_dirs'] == str(paths.templates_dir)
an_object = object()
assert pystache_init.kwargs['escape'](an_object) is an_object
assert sut.written == set()
def test_render(pystache_renderer, sut):
data = mock.Mock()
output = mock.Mock()
with mock.patch("builtins.open", mock.mock_open()) as output_stream:
sut.render(data, output)
assert pystache_renderer.mock_calls == [
mock.call.render_name(data.template, data, generator=paths.exe_file),
], pystache_renderer.mock_calls
assert output_stream.mock_calls == [
mock.call(output, 'w'),
mock.call().__enter__(),
mock.call().write(pystache_renderer.render_name.return_value),
mock.call().__exit__(None, None, None),
]
assert sut.written == {output}
def test_written(sut):
data = [mock.Mock() for _ in range(4)]
output = [mock.Mock() for _ in data]
with mock.patch("builtins.open", mock.mock_open()) as output_stream:
for d, o in zip(data, output):
sut.render(d, o)
assert sut.written == set(output)
def test_cleanup(sut):
data = [mock.Mock() for _ in range(4)]
output = [mock.Mock() for _ in data]
with mock.patch("builtins.open", mock.mock_open()) as output_stream:
for d, o in zip(data, output):
sut.render(d, o)
expected_erased = [mock.Mock() for _ in range(3)]
existing = set(expected_erased + output[2:])
sut.cleanup(existing)
for f in expected_erased:
assert f.mock_calls == [mock.call.unlink(missing_ok=True)]
for f in output:
assert f.unlink.mock_calls == []
if __name__ == '__main__':
sys.exit(pytest.main())

Просмотреть файл

@ -0,0 +1,158 @@
import io
import pathlib
import sys
import mock
import pytest
import swift.codegen.lib.schema as schema
from swift.codegen.test.utils import *
root_name = schema.root_class_name
@pytest.fixture
def load(tmp_path):
file = tmp_path / "schema.yml"
def ret(yml):
write(file, yml)
return schema.load(file)
return ret
def test_empty_schema(load):
ret = load("{}")
assert ret.classes == [schema.Class(root_name)]
assert ret.includes == set()
def test_one_empty_class(load):
ret = load("""
MyClass: {}
""")
assert ret.classes == [
schema.Class(root_name, derived={'MyClass'}),
schema.Class('MyClass', bases={root_name}),
]
def test_two_empty_classes(load):
ret = load("""
MyClass1: {}
MyClass2: {}
""")
assert ret.classes == [
schema.Class(root_name, derived={'MyClass1', 'MyClass2'}),
schema.Class('MyClass1', bases={root_name}),
schema.Class('MyClass2', bases={root_name}),
]
def test_two_empty_chained_classes(load):
ret = load("""
MyClass1: {}
MyClass2:
_extends: MyClass1
""")
assert ret.classes == [
schema.Class(root_name, derived={'MyClass1'}),
schema.Class('MyClass1', bases={root_name}, derived={'MyClass2'}),
schema.Class('MyClass2', bases={'MyClass1'}),
]
def test_empty_classes_diamond(load):
ret = load("""
A: {}
B: {}
C:
_extends:
- A
- B
""")
assert ret.classes == [
schema.Class(root_name, derived={'A', 'B'}),
schema.Class('A', bases={root_name}, derived={'C'}),
schema.Class('B', bases={root_name}, derived={'C'}),
schema.Class('C', bases={'A', 'B'}),
]
def test_dir(load):
ret = load("""
A:
_dir: other/dir
""")
assert ret.classes == [
schema.Class(root_name, derived={'A'}),
schema.Class('A', bases={root_name}, dir=pathlib.Path("other/dir")),
]
def test_directory_filter(load):
ret = load("""
_directories:
first/dir: '[xy]'
second/dir: foo$
third/dir: bar$
Afoo: {}
Bbar: {}
Abar: {}
Bfoo: {}
Ax: {}
Ay: {}
A: {}
""")
assert ret.classes == [
schema.Class(root_name, derived={'Afoo', 'Bbar', 'Abar', 'Bfoo', 'Ax', 'Ay', 'A'}),
schema.Class('Afoo', bases={root_name}, dir=pathlib.Path("second/dir")),
schema.Class('Bbar', bases={root_name}, dir=pathlib.Path("third/dir")),
schema.Class('Abar', bases={root_name}, dir=pathlib.Path("third/dir")),
schema.Class('Bfoo', bases={root_name}, dir=pathlib.Path("second/dir")),
schema.Class('Ax', bases={root_name}, dir=pathlib.Path("first/dir")),
schema.Class('Ay', bases={root_name}, dir=pathlib.Path("first/dir")),
schema.Class('A', bases={root_name}, dir=pathlib.Path()),
]
def test_directory_filter_override(load):
ret = load("""
_directories:
one/dir: ^A$
A:
_dir: other/dir
""")
assert ret.classes == [
schema.Class(root_name, derived={'A'}),
schema.Class('A', bases={root_name}, dir=pathlib.Path("other/dir")),
]
def test_lowercase_rejected(load):
with pytest.raises(AssertionError):
load("aLowercase: {}")
def test_digit_rejected(load):
with pytest.raises(AssertionError):
load("1digit: {}")
def test_properties(load):
ret = load("""
A:
one: string
two: int?
three: bool*
""")
assert ret.classes == [
schema.Class(root_name, derived={'A'}),
schema.Class('A', bases={root_name}, properties=[
schema.SingleProperty('one', 'string'),
schema.OptionalProperty('two', 'int'),
schema.RepeatedProperty('three', 'bool'),
]),
]
if __name__ == '__main__':
sys.exit(pytest.main())

Просмотреть файл

@ -0,0 +1,50 @@
import pathlib
from unittest import mock
import pytest
from swift.codegen.lib import render, schema
schema_dir = pathlib.Path("a", "dir")
schema_file = schema_dir / "schema.yml"
def write(out, contents=""):
out.parent.mkdir(parents=True, exist_ok=True)
with open(out, "w") as out:
out.write(contents)
@pytest.fixture
def renderer():
return mock.Mock(spec=render.Renderer())
@pytest.fixture
def opts():
return mock.MagicMock()
@pytest.fixture(autouse=True)
def override_paths(tmp_path):
with mock.patch("swift.codegen.lib.paths.swift_dir", tmp_path):
yield
@pytest.fixture
def input(opts, tmp_path):
opts.schema = tmp_path / schema_file
with mock.patch("swift.codegen.lib.schema.load") as load_mock:
load_mock.return_value = schema.Schema([])
yield load_mock.return_value
assert load_mock.mock_calls == [
mock.call(opts.schema)
], load_mock.mock_calls
def run_generation(generate, opts, renderer):
output = {}
renderer.render.side_effect = lambda data, out: output.__setitem__(out, data)
generate(opts, renderer)
return output

Просмотреть файл

@ -15,6 +15,16 @@ answer_to_life_the_universe_and_everything(
// from codegen/schema.yml
@element =
@argument
| @file
| @generic_context
| @iterable_decl_context
| @locatable
| @location
| @type
;
files(
unique int id: @file,
string name: string ref
@ -1886,13 +1896,3 @@ integer_literal_exprs(
unique int id: @integer_literal_expr,
string string_value: string ref
);
@element =
@argument
| @file
| @generic_context
| @iterable_decl_context
| @locatable
| @location
| @type
;