Refactor dry run and update integration tests
Parent: 58b3cf31b7
Commit: 47ff0d88f3
@@ -5,7 +5,7 @@ import os
 from pathlib import Path
 import tempfile
 
-from bigquery_etl.dryrun import sql_file_valid
+from bigquery_etl.dryrun import DryRun
 from bigquery_etl.parse_udf import read_udf_dirs, persistent_udf_as_temp
 from bigquery_etl.util import standard_args
 
@@ -77,7 +77,7 @@ def main():
             tmp_example_file = tmp_dir / file
             tmp_example_file.write_text(dry_run_sql)
 
-            sql_file_valid(str(tmp_example_file))
+            DryRun(str(tmp_example_file)).is_valid()
 
 
 if __name__ == "__main__":
@@ -18,10 +18,6 @@ import glob
 import json
 import sys
 
-DRY_RUN_URL = (
-    "https://us-central1-moz-fx-data-shared-prod.cloudfunctions.net/bigquery-etl-dryrun"
-)
-
 SKIP = {
     # Access Denied
     "sql/activity_stream/impression_stats_flat/view.sql",
@@ -111,77 +107,98 @@ SKIP = {
 }
 
 
-def get_referenced_tables(sqlfile, response=None):
-    """Return referenced tables by dry running the SQL file."""
-    if response is None:
-        response = dry_run_sql_file(sqlfile)
-
-    if not sql_file_valid(sqlfile, response):
-        raise Exception(f"Error when dry running SQL file {sqlfile}")
-
-    if response and response["valid"] and "referencedTables" in response:
-        return response["referencedTables"]
-
-    return []
-
-
-def sql_file_valid(sqlfile, response=None):
-    """Dry run the provided SQL file and check if valid."""
-    if response is None:
-        response = dry_run_sql_file(sqlfile)
-
-    if response is None:
-        return False
-
-    if "errors" in response and len(response["errors"]) == 1:
-        error = response["errors"][0]
-    else:
-        error = None
-
-    if response["valid"]:
-        print(f"{sqlfile:59} OK")
-    elif (
-        error
-        and error.get("code", None) in [400, 403]
-        and "does not have bigquery.tables.create permission for dataset"
-        in error.get("message", "")
-    ):
-        # We want the dryrun service to only have read permissions, so
-        # we expect CREATE VIEW and CREATE TABLE to throw specific
-        # exceptions.
-        print(f"{sqlfile:59} OK")
-    else:
-        print(f"{sqlfile:59} ERROR\n", response["errors"])
-        return False
-
-    return True
-
-
-def dry_run_sql_file(sqlfile):
-    """Dry run the provided SQL file."""
-    sql = open(sqlfile).read()
-
-    try:
-        r = urlopen(
-            Request(
-                DRY_RUN_URL,
-                headers={"Content-Type": "application/json"},
-                data=json.dumps(
-                    {"dataset": basename(dirname(dirname(sqlfile))), "query": sql}
-                ).encode("utf8"),
-                method="POST",
-            )
-        )
-    except Exception as e:
-        print(f"{sqlfile:59} ERROR\n", e)
-        return None
-
-    return json.load(r)
+class DryRun:
+    """Dry run SQL files."""
+
+    DRY_RUN_URL = (
+        "https://us-central1-moz-fx-data-shared-prod.cloudfunctions.net/"
+        "bigquery-etl-dryrun"
+    )
+
+    def __init__(self, sqlfile):
+        """Instantiate DryRun class."""
+        self.sqlfile = sqlfile
+        self.dry_run_result = None
+
+    def _execute(self):
+        """Dry run the provided SQL file."""
+        if self.dry_run_result:
+            return self.dry_run_result
+
+        sql = open(self.sqlfile).read()
+
+        try:
+            r = urlopen(
+                Request(
+                    self.DRY_RUN_URL,
+                    headers={"Content-Type": "application/json"},
+                    data=json.dumps(
+                        {
+                            "dataset": basename(dirname(dirname(self.sqlfile))),
+                            "query": sql,
+                        }
+                    ).encode("utf8"),
+                    method="POST",
+                )
+            )
+        except Exception as e:
+            print(f"{self.sqlfile:59} ERROR\n", e)
+            return None
+
+        self.dry_run_result = json.load(r)
+        return self.dry_run_result
+
+    def get_referenced_tables(self):
+        """Return referenced tables by dry running the SQL file."""
+        response = self._execute()
+
+        if not self.is_valid():
+            raise Exception(f"Error when dry running SQL file {self.sqlfile}")
+
+        if response and response["valid"] and "referencedTables" in response:
+            return response["referencedTables"]
+
+        return []
+
+    def is_valid(self):
+        """Dry run the provided SQL file and check if valid."""
+        response = self._execute()
+
+        if response is None:
+            return False
+
+        if "errors" in response and len(response["errors"]) == 1:
+            error = response["errors"][0]
+        else:
+            error = None
+
+        if response["valid"]:
+            print(f"{self.sqlfile:59} OK")
+        elif (
+            error
+            and error.get("code", None) in [400, 403]
+            and "does not have bigquery.tables.create permission for dataset"
+            in error.get("message", "")
+        ):
+            # We want the dryrun service to only have read permissions, so
+            # we expect CREATE VIEW and CREATE TABLE to throw specific
+            # exceptions.
+            print(f"{self.sqlfile:59} OK")
+        else:
+            print(f"{self.sqlfile:59} ERROR\n", response["errors"])
+            return False
+
+        return True
 
 
 def main():
     """Dry run all SQL files in the sql/ directory."""
     sql_files = [f for f in glob.glob("sql/**/*.sql", recursive=True) if f not in SKIP]
 
+    def sql_file_valid(sqlfile):
+        """Dry run SQL files."""
+        return DryRun(sqlfile).is_valid()
+
     with ThreadPool(8) as p:
         result = p.map(sql_file_valid, sql_files, chunksize=1)
     if all(result):
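Note: sql_file_valid(), dry_run_sql_file(), and get_referenced_tables() are folded into the DryRun class above. A minimal usage sketch of the new interface (the SQL path is only an illustrative example):

    from bigquery_etl.dryrun import DryRun

    # Dry run a single query file against the dryrun Cloud Function.
    dryrun = DryRun("sql/telemetry_derived/clients_daily_v6/query.sql")  # example path
    if dryrun.is_valid():
        # Returns the referencedTables entries from the dry-run response.
        print(dryrun.get_referenced_tables())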
@@ -10,7 +10,7 @@ import logging
 from typing import List, Optional, Tuple
 
 
-from bigquery_etl import dryrun
+from bigquery_etl.dryrun import DryRun
 from bigquery_etl.metadata.parse_metadata import Metadata
 from bigquery_etl.query_scheduling.utils import (
     is_date_string,
@@ -293,7 +293,7 @@ class Task:
         query_files = glob.glob(self.sql_file_path + "/*.sql")
 
         for query_file in query_files:
-            referenced_tables = dryrun.get_referenced_tables(query_file)
+            referenced_tables = DryRun(query_file).get_referenced_tables()
 
             if len(referenced_tables) >= 50:
                 logging.warn(
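Task dependency resolution now instantiates DryRun per query file instead of calling the module-level helper. A short sketch of the same call (the path and the projectId/tableId keys are assumptions; only datasetId is asserted in the tests below):

    from bigquery_etl.dryrun import DryRun

    for table in DryRun("sql/test/query_v1/query.sql").get_referenced_tables():
        # Each entry mirrors a BigQuery referencedTables record.
        print(table["datasetId"], table.get("tableId"))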
@@ -3,13 +3,18 @@ from jinja2 import Environment, PackageLoader
 import os
 from pathlib import Path
 import pytest
+from unittest import mock
 
 from bigquery_etl.query_scheduling.dag_collection import DagCollection
 from bigquery_etl.query_scheduling.dag import InvalidDag, DagParseException
 from bigquery_etl.query_scheduling.task import Task
 from bigquery_etl.metadata.parse_metadata import Metadata
+from bigquery_etl.dryrun import DryRun
 
 TEST_DIR = Path(__file__).parent.parent
+TEST_DRY_RUN_URL = (
+    "https://us-central1-bigquery-etl-integration-test.cloudfunctions.net/dryrun"
+)
 
 
 class TestDagCollection:
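Because DRY_RUN_URL is now a class attribute, the integration tests can point dry runs at the integration-test Cloud Function with mock.patch.object. A self-contained sketch of the pattern used below (the query path is hypothetical):

    from unittest import mock

    from bigquery_etl.dryrun import DryRun

    TEST_DRY_RUN_URL = (
        "https://us-central1-bigquery-etl-integration-test.cloudfunctions.net/dryrun"
    )

    # Patching the class attribute redirects every DryRun created in the block.
    with mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL):
        assert DryRun("sql/test/query_v1/query.sql").is_valid()  # hypothetical path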
@@ -207,6 +212,7 @@ class TestDagCollection:
         ).with_tasks(tasks)
 
     @pytest.mark.integration
+    @mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL)
     def test_to_airflow(self, tmp_path):
         query_file = (
             TEST_DIR
@@ -255,6 +261,7 @@ class TestDagCollection:
         assert result == expected
 
     @pytest.mark.integration
+    @mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL)
     def test_to_airflow_with_dependencies(
         self, tmp_path, project_id, temporary_dataset, bigquery_client
     ):
@@ -378,6 +385,7 @@ class TestDagCollection:
         assert dag_external_dependency == expected_dag_external_dependency
 
     @pytest.mark.integration
+    @mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL)
     def test_public_json_dag_to_airflow(self, tmp_path):
         query_file = (
             TEST_DIR
@@ -422,6 +430,7 @@ class TestDagCollection:
         assert result == expected_dag
 
     @pytest.mark.integration
+    @mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL)
     def test_to_airflow_duplicate_dependencies(self, tmp_path):
         query_file = (
             TEST_DIR

@@ -3,6 +3,7 @@ from pathlib import Path
 import os
 import pytest
 from typing import NewType
+from unittest import mock
 
 from bigquery_etl.query_scheduling.task import (
     Task,
@@ -12,8 +13,12 @@ from bigquery_etl.query_scheduling.task import (
 )
 from bigquery_etl.metadata.parse_metadata import Metadata
 from bigquery_etl.query_scheduling.dag_collection import DagCollection
+from bigquery_etl.dryrun import DryRun
 
 TEST_DIR = Path(__file__).parent.parent
+TEST_DRY_RUN_URL = (
+    "https://us-central1-bigquery-etl-integration-test.cloudfunctions.net/dryrun"
+)
 
 
 class TestTask:
@@ -342,6 +347,7 @@ class TestTask:
         )
 
     @pytest.mark.integration
+    @mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL)
     def test_task_get_dependencies_none(self, tmp_path):
         query_file_path = tmp_path / "sql" / "test" / "query_v1"
         os.makedirs(query_file_path)
@@ -359,6 +365,7 @@ class TestTask:
         assert task.dependencies == []
 
     @pytest.mark.integration
+    @mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL)
     def test_task_get_multiple_dependencies(
         self, tmp_path, bigquery_client, project_id, temporary_dataset
     ):
@@ -415,6 +422,7 @@ class TestTask:
         assert f"{temporary_dataset}__table2__v1" in tables
 
     @pytest.mark.integration
+    @mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL)
     def test_multipart_task_get_dependencies(
         self, tmp_path, bigquery_client, project_id, temporary_dataset
     ):
@@ -475,6 +483,7 @@ class TestTask:
         assert f"{temporary_dataset}__table2__v1" in tables
 
     @pytest.mark.integration
+    @mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL)
     def test_task_get_view_dependencies(
         self, tmp_path, bigquery_client, project_id, temporary_dataset
     ):
@@ -534,6 +543,7 @@ class TestTask:
         assert f"{temporary_dataset}__table2__v1" in tables
 
     @pytest.mark.integration
+    @mock.patch.object(DryRun, "DRY_RUN_URL", TEST_DRY_RUN_URL)
     def test_task_get_nested_view_dependencies(
         self, tmp_path, bigquery_client, project_id, temporary_dataset
     ):

@@ -1,6 +1,6 @@
 import os
 
-from bigquery_etl.dryrun import get_referenced_tables, sql_file_valid, dry_run_sql_file
+from bigquery_etl.dryrun import DryRun
 
 
 class TestDryRun:
@@ -8,33 +8,38 @@ class TestDryRun:
         query_file = tmp_path / "query.sql"
         query_file.write_text("SELECT 123")
 
-        response = dry_run_sql_file(query_file)
+        dryrun = DryRun(str(query_file))
+        response = dryrun._execute()
         assert response["valid"]
 
     def test_dry_run_invalid_sql_file(self, tmp_path):
         query_file = tmp_path / "query.sql"
         query_file.write_text("SELECT INVALID 123")
 
-        response = dry_run_sql_file(query_file)
+        dryrun = DryRun(str(query_file))
+        response = dryrun._execute()
         assert response["valid"] is False
 
     def test_sql_file_valid(self, tmp_path):
         query_file = tmp_path / "query.sql"
         query_file.write_text("SELECT 123")
 
-        assert sql_file_valid(str(query_file))
+        dryrun = DryRun(str(query_file))
+        assert dryrun.is_valid()
 
     def test_sql_file_invalid(self, tmp_path):
         query_file = tmp_path / "query.sql"
         query_file.write_text("SELECT INVALID 123")
 
-        assert sql_file_valid(str(query_file)) is False
+        dryrun = DryRun(str(query_file))
+        assert dryrun.is_valid() is False
 
     def test_get_referenced_tables_empty(self, tmp_path):
         query_file = tmp_path / "query.sql"
         query_file.write_text("SELECT 123")
 
-        assert get_referenced_tables(str(query_file)) == []
+        dryrun = DryRun(str(query_file))
+        assert dryrun.get_referenced_tables() == []
 
     def test_get_referenced_tables(self, tmp_path):
         os.makedirs(tmp_path / "telmetry_derived")
@@ -43,7 +48,8 @@ class TestDryRun:
             "SELECT * FROM telemetry_derived.clients_daily_v6 "
             "WHERE submission_date = '2020-01-01'"
         )
-        response = get_referenced_tables(str(query_file))
+        dryrun = DryRun(str(query_file))
+        response = dryrun.get_referenced_tables()
 
         assert len(response) == 1
         assert response[0]["datasetId"] == "telemetry_derived"