bigquery-etl/bigquery_etl/dryrun.py

"""
Dry run query files.
Passes all queries to a Cloud Function that will run the
queries with the dry_run option enabled.
We could provision BigQuery credentials to the CircleCI job to allow it to run
the queries directly, but there is no way to restrict permissions such that
only dry runs can be performed. In order to reduce risk of CI or local users
accidentally running queries during tests and overwriting production data, we
proxy the queries through the dry run service endpoint.
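
Example usage (illustrative; the query path below is hypothetical):

    from bigquery_etl.dryrun import DryRun

    DryRun("sql/moz-fx-data-shared-prod/telemetry_derived/example_v1/query.sql").is_valid()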
"""
import glob
import json
import re
import sys
from enum import Enum
from os.path import basename, dirname, exists
from pathlib import Path
from typing import Optional, Set
from urllib.request import Request, urlopen

import click
import google.auth
from google.auth.transport.requests import Request as GoogleAuthRequest
from google.cloud import bigquery
from google.oauth2.id_token import fetch_id_token

from .config import ConfigLoader
from .metadata.parse_metadata import Metadata
from .util.common import render

try:
from functools import cached_property # type: ignore
except ImportError:
# python 3.7 compatibility
from backports.cached_property import cached_property # type: ignore


def get_credentials(auth_req: Optional[GoogleAuthRequest] = None):
"""Get GCP credentials."""
auth_req = auth_req or GoogleAuthRequest()
credentials, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
credentials.refresh(auth_req)
return credentials


def get_id_token(dry_run_url=ConfigLoader.get("dry_run", "function"), credentials=None):
"""Get token to authenticate against Cloud Function."""
auth_req = GoogleAuthRequest()
credentials = credentials or get_credentials(auth_req)
if hasattr(credentials, "id_token"):
        # Get the ID token from the default credentials for the current
        # environment (created via the Cloud SDK).
id_token = credentials.id_token
else:
        # If the GOOGLE_APPLICATION_CREDENTIALS environment variable is set to a
        # service account JSON file, the ID token is acquired using those
        # service account credentials.
id_token = fetch_id_token(auth_req, dry_run_url)
return id_token


class Errors(Enum):
"""DryRun errors that require special handling."""
READ_ONLY = 1
DATE_FILTER_NEEDED = 2
DATE_FILTER_NEEDED_AND_SYNTAX = 3


class DryRun:
"""Dry run SQL files."""

    def __init__(
self,
sqlfile,
content=None,
strip_dml=False,
use_cloud_function=True,
client=None,
respect_skip=True,
sql_dir=ConfigLoader.get("default", "sql_dir"),
id_token=None,
credentials=None,
project=None,
dataset=None,
table=None,
):
"""Instantiate DryRun class."""
self.sqlfile = sqlfile
self.content = content
self.strip_dml = strip_dml
self.use_cloud_function = use_cloud_function
self.bq_client = client
self.respect_skip = respect_skip
self.dry_run_url = ConfigLoader.get("dry_run", "function")
self.sql_dir = sql_dir
self.id_token = (
id_token
if not use_cloud_function or id_token
else get_id_token(self.dry_run_url)
)
self.credentials = credentials
self.project = project
self.dataset = dataset
self.table = table
try:
self.metadata = Metadata.of_query_file(self.sqlfile)
except FileNotFoundError:
self.metadata = None
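        # Deferred import, presumably to avoid a circular import with
        # bigquery_etl.cli.utils (the same pattern is used in validate_schema).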
from bigquery_etl.cli.utils import is_authenticated
if not is_authenticated():
print(
"Authentication to GCP required. Run `gcloud auth login --update-adc` "
"and check that the project is set correctly."
)
sys.exit(1)

    @cached_property
def client(self):
"""Get BigQuery client instance."""
if self.use_cloud_function:
return None
return self.bq_client or bigquery.Client(credentials=self.credentials)

    @staticmethod
def skipped_files(sql_dir=ConfigLoader.get("default", "sql_dir")) -> Set[str]:
"""Return files skipped by dry run."""
default_sql_dir = Path(ConfigLoader.get("default", "sql_dir"))
sql_dir = Path(sql_dir)
file_pattern_re = re.compile(rf"^{re.escape(str(default_sql_dir))}/")
skip_files = {
file
for skip in ConfigLoader.get("dry_run", "skip", fallback=[])
for file in glob.glob(
file_pattern_re.sub(f"{str(sql_dir)}/", skip),
recursive=True,
)
}
        # Update the skip list to include queries that are renamed when
        # deployed to the stage (test) project.
test_project = ConfigLoader.get("default", "test_project", fallback="")
file_pattern_re = re.compile(r"sql/([^\/]+)/([^/]+)(/?.*|$)")
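        # For example (the test project and query path are illustrative), a skip entry like
        #   sql/moz-fx-data-shared-prod/telemetry_derived/some_table_v1/query.sql
        # is also globbed as
        #   sql/<test_project>/telemetry_derived_moz_fx_data_shared_prod*/some_table_v1/query.sql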
skip_files.update(
[
file
for skip in ConfigLoader.get("dry_run", "skip", fallback=[])
for file in glob.glob(
file_pattern_re.sub(
lambda x: f"sql/{test_project}/{x.group(2)}_{x.group(1).replace('-', '_')}*{x.group(3)}",
skip,
),
recursive=True,
)
]
)
return skip_files

    def skip(self):
"""Determine if dry run should be skipped."""
return self.respect_skip and self.sqlfile in self.skipped_files(
sql_dir=self.sql_dir
)

    def get_sql(self):
"""Get SQL content."""
if exists(self.sqlfile):
file_path = Path(self.sqlfile)
sql = render(
file_path.name,
format=False,
template_folder=file_path.parent.absolute(),
)
else:
raise ValueError(f"Invalid file path: {self.sqlfile}")
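        # With strip_dml, drop the CREATE OR REPLACE VIEW / CREATE MATERIALIZED VIEW
        # wrapper so that only the underlying SELECT statement is dry run.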
if self.strip_dml:
sql = re.sub(
"CREATE OR REPLACE VIEW.*?AS",
"",
sql,
flags=re.DOTALL,
)
sql = re.sub(
"CREATE MATERIALIZED VIEW.*?AS",
"",
sql,
flags=re.DOTALL,
)
return sql

    @cached_property
def dry_run_result(self):
"""Dry run the provided SQL file."""
if self.content:
sql = self.content
else:
sql = self.get_sql()
if self.metadata:
# use metadata to rewrite date-type query params as submission_date
date_params = [
query_param
for query_param in (
self.metadata.scheduling.get("date_partition_parameter"),
*(
param.split(":", 1)[0]
for param in self.metadata.scheduling.get("parameters", [])
if re.fullmatch(r"[^:]+:DATE:{{.*ds.*}}", param)
),
)
if query_param and query_param != "submission_date"
]
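            # For example (parameter name is illustrative): a scheduling parameter
            # "activity_date:DATE:{{ds}}" yields "activity_date" above, so any
            # "@activity_date" reference in the SQL is rewritten to "@submission_date".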
if date_params:
pattern = re.compile(
"@("
+ "|".join(date_params)
# match whole query parameter names
+ ")(?![a-zA-Z0-9_])"
)
sql = pattern.sub("@submission_date", sql)
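        # The target project and dataset are inferred from the query path,
        # which follows the sql/<project>/<dataset>/<table>/ layout.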
project = basename(dirname(dirname(dirname(self.sqlfile))))
dataset = basename(dirname(dirname(self.sqlfile)))
try:
if self.use_cloud_function:
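                # The dry run Cloud Function receives the project, dataset,
                # (optionally) table, and query as a JSON payload and returns
                # the dry run result as JSON.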
json_data = {
"project": self.project or project,
"dataset": self.dataset or dataset,
"query": sql,
}
if self.table:
json_data["table"] = self.table
r = urlopen(
Request(
self.dry_run_url,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.id_token}",
},
data=json.dumps(json_data).encode("utf8"),
method="POST",
)
)
return json.load(r)
else:
self.client.project = project
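                # When dry running directly against BigQuery, pass a placeholder
                # @submission_date value so parameterized queries can be validated.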
job_config = bigquery.QueryJobConfig(
dry_run=True,
use_query_cache=False,
default_dataset=f"{project}.{dataset}",
query_parameters=[
bigquery.ScalarQueryParameter(
"submission_date", "DATE", "2019-01-01"
)
],
)
job = self.client.query(sql, job_config=job_config)
try:
dataset_labels = self.client.get_dataset(job.default_dataset).labels
except Exception as e:
                    # Most users do not have the bigquery.datasets.get permission in
                    # moz-fx-data-shared-prod.
                    # This should not prevent the dry run from running, since the
                    # dataset labels are usually not required.
if "Permission bigquery.datasets.get denied on dataset" in str(e):
dataset_labels = []
else:
raise e
result = {
"valid": True,
"referencedTables": [
ref.to_api_repr() for ref in job.referenced_tables
],
"schema": (
job._properties.get("statistics", {})
.get("query", {})
.get("schema", {})
),
"datasetLabels": dataset_labels,
}
if (
self.project is not None
and self.table is not None
and self.dataset is not None
):
table = self.client.get_table(
f"{self.project}.{self.dataset}.{self.table}"
)
result["tableMetadata"] = {
"tableType": table.table_type,
"friendlyName": table.friendly_name,
"schema": {
"fields": [field.to_api_repr() for field in table.schema]
},
}
return result
except Exception as e:
print(f"{self.sqlfile!s:59} ERROR\n", e)
return None

    def get_referenced_tables(self):
"""Return referenced tables by dry running the SQL file."""
if not self.skip() and not self.is_valid():
raise Exception(f"Error when dry running SQL file {self.sqlfile}")
if self.skip():
print(f"\t...Ignoring dryrun results for {self.sqlfile}")
if (
self.dry_run_result
and self.dry_run_result["valid"]
and "referencedTables" in self.dry_run_result
):
return self.dry_run_result["referencedTables"]
# Handle views that require a date filter
if (
self.dry_run_result
and self.strip_dml
and self.get_error() == Errors.DATE_FILTER_NEEDED
):
# Since different queries require different partition filters
# (submission_date, crash_date, timestamp, submission_timestamp, ...)
# We can extract the filter name from the error message
# (by capturing the next word after "column(s)")
# Example error:
# "Cannot query over table <table_name> without a filter over column(s)
# <date_filter_name> that can be used for partition elimination."
error = self.dry_run_result["errors"][0].get("message", "")
date_filter = find_next_word("column(s)", error)
if "date" in date_filter:
filtered_content = (
f"{self.get_sql()}WHERE {date_filter} > current_date()"
)
if (
DryRun(
self.sqlfile,
filtered_content,
client=self.client,
id_token=self.id_token,
).get_error()
== Errors.DATE_FILTER_NEEDED_AND_SYNTAX
):
# If the date filter (e.g. WHERE crash_date > current_date())
# is added to a query that already has a WHERE clause,
# it will throw an error. To fix this, we need to
# append 'AND' instead of 'WHERE'
filtered_content = (
f"{self.get_sql()}AND {date_filter} > current_date()"
)
if "timestamp" in date_filter:
filtered_content = (
f"{self.get_sql()}WHERE {date_filter} > current_timestamp()"
)
if (
DryRun(
sqlfile=self.sqlfile,
content=filtered_content,
client=self.client,
id_token=self.id_token,
).get_error()
== Errors.DATE_FILTER_NEEDED_AND_SYNTAX
):
filtered_content = (
f"{self.get_sql()}AND {date_filter} > current_timestamp()"
)
stripped_dml_result = DryRun(
sqlfile=self.sqlfile,
content=filtered_content,
client=self.client,
id_token=self.id_token,
)
if (
stripped_dml_result.get_error() is None
and "referencedTables" in stripped_dml_result.dry_run_result
):
return stripped_dml_result.dry_run_result["referencedTables"]
return []

    def get_schema(self):
"""Return the query schema by dry running the SQL file."""
if not self.skip() and not self.is_valid():
raise Exception(f"Error when dry running SQL file {self.sqlfile}")
if self.skip():
print(f"\t...Ignoring dryrun results for {self.sqlfile}")
return {}
if (
self.dry_run_result
and self.dry_run_result["valid"]
and "schema" in self.dry_run_result
):
return self.dry_run_result["schema"]
return {}

    def get_table_schema(self):
"""Return the schema of the provided table."""
if not self.skip() and not self.is_valid():
raise Exception(f"Error when dry running SQL file {self.sqlfile}")
if self.skip():
print(f"\t...Ignoring dryrun results for {self.sqlfile}")
return {}
if (
self.dry_run_result
and self.dry_run_result["valid"]
and "tableMetadata" in self.dry_run_result
):
return self.dry_run_result["tableMetadata"]["schema"]
return []

    def get_dataset_labels(self):
"""Return the labels on the default dataset by dry running the SQL file."""
if not self.skip() and not self.is_valid():
raise Exception(f"Error when dry running SQL file {self.sqlfile}")
if self.skip():
print(f"\t...Ignoring dryrun results for {self.sqlfile}")
return {}
if (
self.dry_run_result
and self.dry_run_result["valid"]
and "datasetLabels" in self.dry_run_result
):
return self.dry_run_result["datasetLabels"]
return {}

    def is_valid(self):
"""Dry run the provided SQL file and check if valid."""
if self.dry_run_result is None:
return False
if self.dry_run_result["valid"]:
print(f"{self.sqlfile!s:59} OK")
elif self.get_error() == Errors.READ_ONLY:
# We want the dryrun service to only have read permissions, so
# we expect CREATE VIEW and CREATE TABLE to throw specific
# exceptions.
print(f"{self.sqlfile!s:59} OK")
elif self.get_error() == Errors.DATE_FILTER_NEEDED and self.strip_dml:
            # With the strip_dml flag, some queries require a partition filter
            # (submission_date, submission_timestamp, etc.) to run.
            # We mark these queries as valid and add a date filter
            # in get_referenced_tables().
print(f"{self.sqlfile!s:59} OK but DATE FILTER NEEDED")
else:
print(f"{self.sqlfile!s:59} ERROR\n", self.dry_run_result["errors"])
return False
return True

    def errors(self):
"""Dry run the provided SQL file and return errors."""
if self.dry_run_result is None:
return []
return self.dry_run_result.get("errors", [])

    def get_error(self) -> Optional[Errors]:
"""Get specific errors for edge case handling."""
errors = self.errors()
if len(errors) != 1:
return None
error = errors[0]
if error and error.get("code") in [400, 403]:
error_message = error.get("message", "")
if (
"does not have bigquery.tables.create permission for dataset"
in error_message
or "Permission bigquery.tables.create denied" in error_message
or "Permission bigquery.datasets.update denied" in error_message
):
return Errors.READ_ONLY
if "without a filter over column(s)" in error_message:
return Errors.DATE_FILTER_NEEDED
if (
"Syntax error: Expected end of input but got keyword WHERE"
in error_message
):
return Errors.DATE_FILTER_NEEDED_AND_SYNTAX
return None

    def validate_schema(self):
"""Check whether schema is valid."""
# delay import to prevent circular imports in 'bigquery_etl.schema'
from .schema import SCHEMA_FILE, Schema
if (
self.skip()
or basename(self.sqlfile) == "script.sql"
or str(self.sqlfile).endswith(".py")
): # noqa E501
print(f"\t...Ignoring schema validation for {self.sqlfile}")
return True
query_file_path = Path(self.sqlfile)
query_schema = Schema.from_json(self.get_schema())
if self.errors():
            # Skip schema validation when the dry run produced errors that
            # self.get_schema() did not raise.
click.echo(f"\t...Ignoring schema validation for {self.sqlfile}")
return True
existing_schema_path = query_file_path.parent / SCHEMA_FILE
if not existing_schema_path.is_file():
click.echo(f"No schema file defined for {query_file_path}", err=True)
return True
table_name = query_file_path.parent.name
dataset_name = query_file_path.parent.parent.name
project_name = query_file_path.parent.parent.parent.name
partitioned_by = None
if (
self.metadata
and self.metadata.bigquery
and self.metadata.bigquery.time_partitioning
):
partitioned_by = self.metadata.bigquery.time_partitioning.field
table_schema = Schema.for_table(
project_name,
dataset_name,
table_name,
client=self.client,
id_token=self.id_token,
partitioned_by=partitioned_by,
)
# This check relies on the new schema being deployed to prod
if not query_schema.compatible(table_schema):
click.echo(
click.style(
f"ERROR: Schema for query in {query_file_path} "
f"incompatible with schema deployed for "
f"{project_name}.{dataset_name}.{table_name}\n"
f"Did you deploy new the schema to prod yet?",
fg="red",
),
err=True,
)
return False
else:
existing_schema = Schema.from_schema_file(existing_schema_path)
if not existing_schema.equal(query_schema):
click.echo(
click.style(
f"ERROR: Schema defined in {existing_schema_path} "
f"incompatible with query {query_file_path}",
fg="red",
),
err=True,
)
return False
click.echo(f"Schemas for {query_file_path} are valid.")
return True


def sql_file_valid(sqlfile):
"""Dry run SQL files."""
return DryRun(sqlfile).is_valid()


def find_next_word(target, source):
"""Find the next word in a string."""
split = source.split()
    for i, w in enumerate(split):
        if w == target:
            # get the next word, and remove quotations from the column name
            return split[i + 1].replace("'", "")
    # target word not found
    return None