From 33e7b7499a64c6278c2e9b5003a3af8f28d8ef1b Mon Sep 17 00:00:00 2001
From: Jeff Klukas
Date: Tue, 2 Mar 2021 09:13:08 -0500
Subject: [PATCH] Bug 1695654: Add a method to fetch dataset labels from dryrun service results (#1858)

* Add a method to fetch dataset labels from dryrun service results

* Update bigquery_etl/dryrun.py

Co-authored-by: Anna Scholtz

* Adapt tests to include real datasets in paths

Co-authored-by: Anna Scholtz
---
 bigquery_etl/dryrun.py | 23 ++++++++++++++++++++--
 tests/test_dryrun.py   | 43 +++++++++++++++++++++++-------------------
 2 files changed, 45 insertions(+), 21 deletions(-)

diff --git a/bigquery_etl/dryrun.py b/bigquery_etl/dryrun.py
index 17e7ac1efc..9e93d8e117 100644
--- a/bigquery_etl/dryrun.py
+++ b/bigquery_etl/dryrun.py
@@ -328,6 +328,24 @@ class DryRun:

         return {}

+    def get_dataset_labels(self):
+        """Return the labels on the default dataset by dry running the SQL file."""
+        if self.sqlfile not in SKIP and not self.is_valid():
+            raise Exception(f"Error when dry running SQL file {self.sqlfile}")
+
+        if self.sqlfile in SKIP:
+            print(f"\t...Ignoring dryrun results for {self.sqlfile}")
+            return {}
+
+        if (
+            self.dry_run_result
+            and self.dry_run_result["valid"]
+            and "datasetLabels" in self.dry_run_result
+        ):
+            return self.dry_run_result["datasetLabels"]
+
+        return {}
+
     def is_valid(self):
         """Dry run the provided SQL file and check if valid."""
         if self.dry_run_result is None:
@@ -360,8 +378,9 @@ class DryRun:

     def get_error(self):
         """Get specific errors for edge case handling."""
-        if "errors" in self.dry_run_result and len(self.dry_run_result["errors"]) == 1:
-            error = self.dry_run_result["errors"][0]
+        errors = self.dry_run_result.get("errors", None)
+        if errors and len(errors) == 1:
+            error = errors[0]
         else:
             error = None
         if error and error.get("code", None) in [400, 403]:
diff --git a/tests/test_dryrun.py b/tests/test_dryrun.py
index 36dd8b0479..a839d59390 100644
--- a/tests/test_dryrun.py
+++ b/tests/test_dryrun.py
@@ -5,32 +5,39 @@ import pytest
 from bigquery_etl.dryrun import DryRun, Errors


+@pytest.fixture
+def tmp_query_path(tmp_path):
+    p = tmp_path / "telemetry_derived" / "mytable"
+    p.mkdir(parents=True)
+    return p
+
+
 class TestDryRun:
-    def test_dry_run_sql_file(self, tmp_path):
-        query_file = tmp_path / "query.sql"
+    def test_dry_run_sql_file(self, tmp_query_path):
+        query_file = tmp_query_path / "query.sql"
         query_file.write_text("SELECT 123")

         dryrun = DryRun(str(query_file))
         response = dryrun.dry_run_result
         assert response["valid"]

-    def test_dry_run_invalid_sql_file(self, tmp_path):
-        query_file = tmp_path / "query.sql"
+    def test_dry_run_invalid_sql_file(self, tmp_query_path):
+        query_file = tmp_query_path / "query.sql"
         query_file.write_text("SELECT INVALID 123")

         dryrun = DryRun(str(query_file))
         response = dryrun.dry_run_result
         assert response["valid"] is False

-    def test_sql_file_valid(self, tmp_path):
-        query_file = tmp_path / "query.sql"
+    def test_sql_file_valid(self, tmp_query_path):
+        query_file = tmp_query_path / "query.sql"
         query_file.write_text("SELECT 123")

         dryrun = DryRun(str(query_file))
         assert dryrun.is_valid()

-    def test_view_file_valid(self, tmp_path):
-        view_file = tmp_path / "view.sql"
+    def test_view_file_valid(self, tmp_query_path):
+        view_file = tmp_query_path / "view.sql"
         view_file.write_text(
             """
             SELECT
@@ -45,15 +52,15 @@ class TestDryRun:
         assert dryrun.get_error() is Errors.DATE_FILTER_NEEDED
         assert dryrun.is_valid()

-    def test_sql_file_invalid(self, tmp_path):
-        query_file = tmp_path / "query.sql"
+    def test_sql_file_invalid(self, tmp_query_path):
+        query_file = tmp_query_path / "query.sql"
         query_file.write_text("SELECT INVALID 123")

         dryrun = DryRun(str(query_file))
         assert dryrun.is_valid() is False

-    def test_get_referenced_tables_empty(self, tmp_path):
-        query_file = tmp_path / "query.sql"
+    def test_get_referenced_tables_empty(self, tmp_query_path):
+        query_file = tmp_query_path / "query.sql"
         query_file.write_text("SELECT 123")

         dryrun = DryRun(str(query_file))
@@ -70,9 +77,8 @@ class TestDryRun:
         with pytest.raises(ValueError):
             DryRun(sqlfile="invalid path").get_sql()

-    def test_get_referenced_tables(self, tmp_path):
-        os.makedirs(tmp_path / "telmetry_derived")
-        query_file = tmp_path / "telmetry_derived" / "query.sql"
+    def test_get_referenced_tables(self, tmp_query_path):
+        query_file = tmp_query_path / "query.sql"
         query_file.write_text(
             "SELECT * FROM telemetry_derived.clients_daily_v6 "
             "WHERE submission_date = '2020-01-01'"
@@ -83,7 +89,7 @@ class TestDryRun:
         assert query_dryrun[0]["datasetId"] == "telemetry_derived"
         assert query_dryrun[0]["tableId"] == "clients_daily_v6"

-        view_file = tmp_path / "telmetry_derived" / "view.sql"
+        view_file = tmp_query_path / "view.sql"
         view_file.write_text(
             """
             CREATE OR REPLACE VIEW
@@ -121,9 +127,8 @@ class TestDryRun:
         assert multiple_tables[1]["datasetId"] == "org_mozilla_firefox_stable"
         assert multiple_tables[1]["tableId"] == "baseline_v1"

-    def test_get_error(self, tmp_path):
-        os.makedirs(tmp_path / "telemetry")
-        view_file = tmp_path / "telemetry" / "view.sql"
+    def test_get_error(self, tmp_query_path):
+        view_file = tmp_query_path / "view.sql"
         view_file.write_text(
             """