Bug 1695654: Add a method to fetch dataset labels from dryrun service results (#1858)

* Add a method to fetch dataset labels from dryrun service results

* Update bigquery_etl/dryrun.py

Co-authored-by: Anna Scholtz <anna@scholtzan.net>

* Adapt tests to include real datasets in paths

Co-authored-by: Anna Scholtz <anna@scholtzan.net>
This commit is contained in:
Jeff Klukas 2021-03-02 09:13:08 -05:00 коммит произвёл GitHub
Родитель 5f19164105
Коммит 33e7b7499a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 45 добавлений и 21 удалений

Просмотреть файл

@@ -328,6 +328,24 @@ class DryRun:
return {}
def get_dataset_labels(self):
    """Return the labels on the default dataset from the dry run result.

    Files listed in SKIP are not validated: a notice is printed and an
    empty dict is returned. For any other file, an Exception is raised
    when ``is_valid()`` reports failure. An empty dict is returned when
    the dry run result is empty, not valid, or has no "datasetLabels" key.
    """
    # Skipped files short-circuit before any validity check is attempted.
    if self.sqlfile in SKIP:
        print(f"\t...Ignoring dryrun results for {self.sqlfile}")
        return {}

    if not self.is_valid():
        raise Exception(f"Error when dry running SQL file {self.sqlfile}")

    result = self.dry_run_result
    if not result or not result["valid"]:
        return {}
    return result["datasetLabels"] if "datasetLabels" in result else {}
def is_valid(self):
"""Dry run the provided SQL file and check if valid."""
if self.dry_run_result is None:
@@ -360,8 +378,9 @@ class DryRun:
def get_error(self):
"""Get specific errors for edge case handling."""
if "errors" in self.dry_run_result and len(self.dry_run_result["errors"]) == 1:
error = self.dry_run_result["errors"][0]
errors = self.dry_run_result.get("errors", None)
if errors and len(errors) == 1:
error = errors[0]
else:
error = None
if error and error.get("code", None) in [400, 403]:

Просмотреть файл

@@ -5,32 +5,39 @@ import pytest
from bigquery_etl.dryrun import DryRun, Errors
@pytest.fixture
def tmp_query_path(tmp_path):
    """Create and return a telemetry_derived/mytable directory in tmp_path."""
    query_dir = tmp_path.joinpath("telemetry_derived", "mytable")
    query_dir.mkdir(parents=True)
    return query_dir
class TestDryRun:
def test_dry_run_sql_file(self, tmp_path):
query_file = tmp_path / "query.sql"
def test_dry_run_sql_file(self, tmp_query_path):
    """The dry run result of a simple SELECT is flagged as valid."""
    sql_path = tmp_query_path / "query.sql"
    sql_path.write_text("SELECT 123")

    result = DryRun(str(sql_path)).dry_run_result
    assert result["valid"]
def test_dry_run_invalid_sql_file(self, tmp_path):
query_file = tmp_path / "query.sql"
def test_dry_run_invalid_sql_file(self, tmp_query_path):
    """The dry run result of malformed SQL has "valid" set to False."""
    sql_path = tmp_query_path / "query.sql"
    sql_path.write_text("SELECT INVALID 123")

    result = DryRun(str(sql_path)).dry_run_result
    assert result["valid"] is False
def test_sql_file_valid(self, tmp_path):
query_file = tmp_path / "query.sql"
def test_sql_file_valid(self, tmp_query_path):
    """is_valid() is truthy for a file containing a simple SELECT."""
    sql_path = tmp_query_path / "query.sql"
    sql_path.write_text("SELECT 123")

    assert DryRun(str(sql_path)).is_valid()
def test_view_file_valid(self, tmp_path):
view_file = tmp_path / "view.sql"
def test_view_file_valid(self, tmp_query_path):
view_file = tmp_query_path / "view.sql"
view_file.write_text(
"""
SELECT
@@ -45,15 +52,15 @@ class TestDryRun:
assert dryrun.get_error() is Errors.DATE_FILTER_NEEDED
assert dryrun.is_valid()
def test_sql_file_invalid(self, tmp_path):
query_file = tmp_path / "query.sql"
def test_sql_file_invalid(self, tmp_query_path):
    """is_valid() returns False for a file containing malformed SQL."""
    sql_path = tmp_query_path / "query.sql"
    sql_path.write_text("SELECT INVALID 123")

    assert DryRun(str(sql_path)).is_valid() is False
def test_get_referenced_tables_empty(self, tmp_path):
query_file = tmp_path / "query.sql"
def test_get_referenced_tables_empty(self, tmp_query_path):
query_file = tmp_query_path / "query.sql"
query_file.write_text("SELECT 123")
dryrun = DryRun(str(query_file))
@@ -70,9 +77,8 @@ class TestDryRun:
with pytest.raises(ValueError):
DryRun(sqlfile="invalid path").get_sql()
def test_get_referenced_tables(self, tmp_path):
os.makedirs(tmp_path / "telmetry_derived")
query_file = tmp_path / "telmetry_derived" / "query.sql"
def test_get_referenced_tables(self, tmp_query_path):
query_file = tmp_query_path / "query.sql"
query_file.write_text(
"SELECT * FROM telemetry_derived.clients_daily_v6 "
"WHERE submission_date = '2020-01-01'"
@@ -83,7 +89,7 @@ class TestDryRun:
assert query_dryrun[0]["datasetId"] == "telemetry_derived"
assert query_dryrun[0]["tableId"] == "clients_daily_v6"
view_file = tmp_path / "telmetry_derived" / "view.sql"
view_file = tmp_query_path / "view.sql"
view_file.write_text(
"""
CREATE OR REPLACE VIEW
@@ -121,9 +127,8 @@ class TestDryRun:
assert multiple_tables[1]["datasetId"] == "org_mozilla_firefox_stable"
assert multiple_tables[1]["tableId"] == "baseline_v1"
def test_get_error(self, tmp_path):
os.makedirs(tmp_path / "telemetry")
view_file = tmp_path / "telemetry" / "view.sql"
def test_get_error(self, tmp_query_path):
view_file = tmp_query_path / "view.sql"
view_file.write_text(
"""