* Parallelize dependency graph

* Use GCP API to get table schema when not using cloud function

* Reuse GCP credentials

* Update dependency tests

* Remove print
Anna Scholtz 2024-10-30 07:52:05 -07:00 committed by GitHub
Parent 5ce6152d0b
Commit 2add865249
No known key found for this signature
GPG key ID: B5690EEEBB952194
4 changed files: 102 additions and 43 deletions
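
The parallelization in the first bullet is implemented in the dependency changes below by fanning per-file SQL parsing out over a multiprocessing pool. A minimal sketch of that pattern, with an illustrative worker standing in for _extract_table_references (names and the empty result are placeholders, not the repo's exact code):

# Minimal sketch of the pool-based fan-out; worker and inputs are illustrative.
from functools import partial
from multiprocessing.pool import Pool
from pathlib import Path
from typing import List, Tuple


def _parse_one(without_views: bool, path: Path) -> Tuple[Path, List[str]]:
    # stand-in for the per-file work done by _extract_table_references
    return path, []


def get_references(paths, without_views=False, parallelism=8):
    if parallelism <= 1:
        # sequential fallback keeps behaviour identical for --parallelism=1
        return [_parse_one(without_views, p) for p in sorted(paths)]
    with Pool(parallelism) as pool:
        # partial() pins the shared flag so pool.map only varies the path
        return pool.map(partial(_parse_one, without_views), paths)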

View file

@@ -1782,7 +1782,9 @@ def update(
if str(query_file)
not in ConfigLoader.get("schema", "deploy", "skip", fallback=[])
]
dependency_graph = get_dependency_graph([sql_dir], without_views=True)
dependency_graph = get_dependency_graph(
[sql_dir], without_views=True, parallelism=parallelism
)
manager = multiprocessing.Manager()
tmp_tables = manager.dict({})
@@ -1884,7 +1886,7 @@ def _update_query_schema_with_downstream(
# create temporary table with updated schema
if identifier not in tmp_tables:
schema = Schema.from_schema_file(query_file.parent / SCHEMA_FILE)
schema.deploy(tmp_identifier)
schema.deploy(tmp_identifier, credentials)
tmp_tables[identifier] = tmp_identifier
# get downstream dependencies that will be updated in the next iteration
@@ -1966,7 +1968,7 @@ def _update_query_schema(
f"{parent_project}.{tmp_dataset}.{parent_table}_"
+ random_str(12)
)
parent_schema.deploy(tmp_parent_identifier)
parent_schema.deploy(tmp_parent_identifier, credentials=credentials)
tmp_tables[parent_identifier] = tmp_parent_identifier
if existing_schema_path.is_file():
@@ -1980,7 +1982,7 @@ def _update_query_schema(
tmp_identifier = (
f"{project_name}.{tmp_dataset}.{table_name}_{random_str(12)}"
)
existing_schema.deploy(tmp_identifier)
existing_schema.deploy(tmp_identifier, credentials=credentials)
tmp_tables[f"{project_name}.{dataset_name}.{table_name}"] = (
tmp_identifier
)
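
The hunks above thread a credentials object through to Schema.deploy so each temporary-table deployment reuses one authenticated session instead of re-authenticating per call. A small sketch of the general idea, not the repo's exact wiring (the explicit scope is an assumption):

import google.auth
from google.cloud import bigquery

# Authenticate once, then hand the same credentials to every client.
credentials, project = google.auth.default(
    scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
client_a = bigquery.Client(credentials=credentials, project=project)
client_b = bigquery.Client(credentials=credentials, project=project)  # no extra auth round-trip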

View file

@@ -2,8 +2,10 @@
import re
import sys
from functools import partial
from glob import glob
from itertools import groupby
from multiprocessing.pool import Pool
from pathlib import Path
from subprocess import CalledProcessError
from typing import Dict, Iterator, List, Tuple
@@ -130,9 +132,24 @@ def extract_table_references_without_views(path: Path) -> Iterator[str]:
yield ".".join(parts)
def _extract_table_references(without_views, path):
try:
if without_views:
return path, list(extract_table_references_without_views(path))
else:
sql = render(path.name, template_folder=path.parent)
return path, extract_table_references(sql)
except CalledProcessError as e:
raise click.ClickException(f"failed to import jnius: {e}")
except ImportError as e:
raise click.ClickException(*e.args)
except ValueError as e:
raise ValueError(f"Failed to parse {path}: {e}", file=sys.stderr)
def _get_references(
paths: Tuple[str, ...], without_views: bool = False
) -> Iterator[Tuple[Path, List[str]]]:
paths: Tuple[str, ...], without_views: bool = False, parallelism: int = 8
) -> List[Tuple[Path, List[str]]]:
file_paths = {
path
for parent in map(Path, paths or ["sql"])
@@ -144,29 +161,37 @@ def _get_references(
if not path.name.endswith(".template.sql") # skip templates
}
fail = False
for path in sorted(file_paths):
if parallelism <= 1:
try:
if without_views:
yield path, list(extract_table_references_without_views(path))
else:
sql = render(path.name, template_folder=path.parent)
yield path, extract_table_references(sql)
except CalledProcessError as e:
raise click.ClickException(f"failed to import jnius: {e}")
except ImportError as e:
raise click.ClickException(*e.args)
return [
_extract_table_references(without_views, file_path)
for file_path in sorted(file_paths)
]
except ValueError as e:
fail = True
print(f"Failed to parse {path}: {e}", file=sys.stderr)
print(f"Failed to parse file: {e}", file=sys.stderr)
else:
with Pool(parallelism) as pool:
try:
result = pool.map(
partial(_extract_table_references, without_views), file_paths
)
return result
except ValueError as e:
fail = True
print(f"Failed to parse file: {e}", file=sys.stderr)
if fail:
raise click.ClickException("Some paths could not be analyzed")
return []
def get_dependency_graph(
paths: Tuple[str, ...], without_views: bool = False
paths: Tuple[str, ...], without_views: bool = False, parallelism: int = 8
) -> Dict[str, List[str]]:
"""Return the query dependency graph."""
refs = _get_references(paths, without_views=without_views)
refs = _get_references(paths, without_views=without_views, parallelism=parallelism)
dependency_graph = {}
for ref in refs:
@@ -198,9 +223,16 @@ def dependency():
is_flag=True,
help="recursively resolve view references to underlying tables",
)
def show(paths: Tuple[str, ...], without_views: bool):
@click.option(
"--parallelism",
"-p",
default=8,
type=int,
help="Number of threads for parallel processing",
)
def show(paths: Tuple[str, ...], without_views: bool, parallelism: int):
"""Show table references in sql files."""
for path, table_references in _get_references(paths, without_views):
for path, table_references in _get_references(paths, without_views, parallelism):
if table_references:
for table in table_references:
print(f"{path}: {table}")
@@ -223,9 +255,18 @@ def show(paths: Tuple[str, ...], without_views: bool):
is_flag=True,
help="Skip files with existing references rather than failing",
)
def record(paths: Tuple[str, ...], skip_existing):
@click.option(
"--parallelism",
"-p",
default=8,
type=int,
help="Number of threads for parallel processing",
)
def record(paths: Tuple[str, ...], skip_existing, parallelism):
"""Record table references in metadata."""
for parent, group in groupby(_get_references(paths), lambda e: e[0].parent):
for parent, group in groupby(
_get_references(paths, parallelism=parallelism), lambda e: e[0].parent
):
references = {
path.name: table_references
for path, table_references in group

View file

@@ -60,30 +60,46 @@ class Schema:
@classmethod
def for_table(cls, project, dataset, table, partitioned_by=None, *args, **kwargs):
"""Get the schema for a BigQuery table."""
query = f"SELECT * FROM `{project}.{dataset}.{table}`"
if partitioned_by:
query += f" WHERE DATE(`{partitioned_by}`) = DATE('2020-01-01')"
try:
return cls(
dryrun.DryRun(
os.path.join(project, dataset, table, "query.sql"),
query,
project=project,
dataset=dataset,
table=table,
*args,
**kwargs,
).get_schema()
)
if (
"use_cloud_function" not in kwargs
or kwargs["use_cloud_function"] is False
):
if "credentials" in kwargs:
client = bigquery.Client(credentials=kwargs["credentials"])
else:
client = bigquery.Client()
table = client.get_table(f"{project}.{dataset}.{table}")
return cls({"fields": [field.to_api_repr() for field in table.schema]})
else:
query = f"SELECT * FROM `{project}.{dataset}.{table}`"
if partitioned_by:
query += f" WHERE DATE(`{partitioned_by}`) = DATE('2020-01-01')"
return cls(
dryrun.DryRun(
os.path.join(project, dataset, table, "query.sql"),
query,
project=project,
dataset=dataset,
table=table,
*args,
**kwargs,
).get_schema()
)
except Exception as e:
print(f"Cannot get schema for {project}.{dataset}.{table}: {e}")
return cls({"fields": []})
def deploy(self, destination_table: str) -> bigquery.Table:
def deploy(self, destination_table: str, credentials=None) -> bigquery.Table:
"""Deploy the schema to BigQuery named after destination_table."""
client = bigquery.Client()
if credentials:
client = bigquery.Client(credentials=credentials)
else:
client = bigquery.Client()
tmp_schema_file = NamedTemporaryFile()
self.to_json_file(Path(tmp_schema_file.name))
bigquery_schema = client.schema_from_json(tmp_schema_file.name)
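
Taken together, for_table can now read a schema directly via the BigQuery API when the dry-run cloud function is disabled, and deploy accepts the credentials to reuse. A hypothetical call site (project, dataset, and table names are placeholders):

# Hypothetical usage of the updated Schema API; identifiers are made up.
schema = Schema.for_table(
    "my-project",
    "my_dataset",
    "my_table",
    use_cloud_function=False,  # fetch the schema with bigquery.Client.get_table
    credentials=credentials,   # credentials obtained once and reused
)
schema.deploy("my-project.tmp_dataset.my_table_abc123", credentials=credentials)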

View file

@@ -22,7 +22,7 @@ class TestDependency:
with open("foo.sql", "w") as f:
f.write("SELECT 1 FROM test")
result = runner.invoke(dependency_show, ["foo.sql"])
result = runner.invoke(dependency_show, ["foo.sql", "--parallelism=1"])
assert "foo.sql: test\n" == result.output
assert result.exit_code == 0
@@ -32,6 +32,6 @@
with open("test/bar.sql", "w") as f:
f.write("SELECT 1 FROM test_bar")
result = runner.invoke(dependency_show, ["test"])
result = runner.invoke(dependency_show, ["test", "--parallelism=1"])
assert "test/bar.sql: test_bar\ntest/foo.sql: test_foo\n" == result.output
assert result.exit_code == 0