Generate query with shredder mitigation (#6060)

* Auxiliary functions required to generate the query for a backfill with shredder mitigation.

* Exception handling.

* isort & docstrings.

* Apply flake8 to test file.

* Remove variable assignment to different types.

* Make search case insensitive in function.

* Add test cases for function and update naming of a function's parameters for clarity.

* Update bigquery_etl/backfill/shredder_mitigation.py

Co-authored-by: Leli <33942105+lelilia@users.noreply.github.com>

* Add test cases for missing or non-matching parameters where expected. Minimize the calls to get_bigquery_type().

* Encapsulate actions to generate and run custom queries to generate the subsets for shredder mitigation.

* Query template for shredder mitigation.

* Query template for shredder mitigation and formatting.

* Add check for "GROUP BY 1, 2, 3", improve code readability, remove unnecessary properties in classes.

* Test coverage. Check for "GROUP BY 1, 2, 3", improve readability, remove unneeded properties in class Subset.

* Increase test coverage. Expand DataType INTEGER required for UNION queries.

* Increase test coverage. Expand DataType INTEGER required for UNION queries.

* Separate INTEGER and NUMERIC types.

* Move util functions and convert method to property, both to resolve a circular import. Adjust tests. Update function return and tests.

* Adding backfill_date to exception message. Formatting.

* Adding backfill_date to exception message. Formatting.

---------

Co-authored-by: Leli <33942105+lelilia@users.noreply.github.com>
Lucia 2024-09-06 16:21:56 +02:00 committed by GitHub
Parent 1570bf1682
Commit 9fc513ba9f
No key found matching this signature
GPG key ID: B5690EEEBB952194
9 changed files with 1406 additions and 243 deletions

View file

@ -0,0 +1,23 @@
-- Query generated using a template for shredder mitigation.
WITH {{ new_version_cte }} AS (
{{ new_version }}
),
{{ new_agg_cte }} AS (
{{ new_agg }}
),
{{ previous_agg_cte }} AS (
{{ previous_agg }}
),
{{ shredded_cte }} AS (
{{ shredded }}
)
SELECT
{{ final_select }}
FROM
{{ new_version_cte }}
UNION ALL
SELECT
{{ final_select }}
FROM
{{ shredded_cte }}
;
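
The placeholders above are Jinja2 variables; the commit fills them in via the Environment/FileSystemLoader setup near the end of shredder_mitigation.py further down. As a rough, stand-alone sketch of that rendering step (the file location and all CTE names and SQL snippets here are illustrative placeholders, not values taken from the commit):

from jinja2 import Environment, FileSystemLoader

# Assumes the template above has been saved as
# query_with_shredder_mitigation_template.sql in the working directory.
env = Environment(loader=FileSystemLoader("."))
template = env.get_template("query_with_shredder_mitigation_template.sql")

# Illustrative placeholder values; in the commit they are built from the
# previous and new versions of the destination table's query.sql.
rendered_sql = template.render(
    new_version_cte="new_version",
    new_version="SELECT column_1, column_2, metric_1 FROM upstream_1 GROUP BY column_1, column_2",
    new_agg_cte="new_agg",
    new_agg="SELECT submission_date, column_1, SUM(metric_1) AS metric_1 FROM new_version GROUP BY ALL",
    previous_agg_cte="previous_agg",
    previous_agg="SELECT submission_date, column_1, SUM(metric_1) AS metric_1 FROM `project.dataset.table_v1` GROUP BY ALL",
    shredded_cte="shredded",
    shredded="SELECT previous_agg.submission_date, previous_agg.column_1, CAST(NULL AS STRING) AS column_2, previous_agg.metric_1 - IFNULL(new_agg.metric_1, 0) AS metric_1 FROM previous_agg LEFT JOIN new_agg USING (submission_date, column_1)",
    final_select="column_1, column_2, metric_1",
)
print(rendered_sql)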

View file

@ -1,14 +1,32 @@
"""Generate a query to backfill an aggregate with shredder mitigation."""
"""Generate a query with shredder mitigation."""
from datetime import date, datetime, time, timedelta
import os
import re
from datetime import date
from datetime import datetime as dt
from datetime import time, timedelta
from enum import Enum
from pathlib import Path
from types import NoneType
from typing import Any, Optional, Tuple
import attr
import click
from dateutil import parser
from gcloud.exceptions import NotFound # type: ignore
from google.cloud import bigquery
from jinja2 import Environment, FileSystemLoader
from bigquery_etl.format_sql.formatter import reformat
from bigquery_etl.metadata.parse_metadata import METADATA_FILE, Metadata
from bigquery_etl.util.common import extract_last_group_by_from_query, write_sql
PREVIOUS_DATE = (dt.now() - timedelta(days=2)).date()
SUFFIX = dt.now().strftime("%Y%m%d%H%M%S")
TEMP_DATASET = "tmp"
SUFFIX = datetime.now().strftime("%Y%m%d%H%M%S")
PREVIOUS_DATE = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")
THIS_PATH = Path(os.path.dirname(__file__))
DEFAULT_PROJECT_ID = "moz-fx-data-shared-prod"
QUERY_WITH_MITIGATION_NAME = "query_with_shredder_mitigation"
class ColumnType(Enum):
@ -29,97 +47,259 @@ class ColumnStatus(Enum):
class DataTypeGroup(Enum):
"""Data types in BigQuery. Not including ARRAY and STRUCT as these are not expected in aggregate tables."""
"""Data types in BigQuery. Not supported/expected in aggregates: TIMESTAMP, ARRAY, STRUCT."""
STRING = ("STRING", "BYTES")
BOOLEAN = "BOOLEAN"
INTEGER = ("INTEGER", "INT64", "INT", "SMALLINT", "TINYINT", "BYTEINT")
NUMERIC = (
"INTEGER",
"NUMERIC",
"BIGNUMERIC",
"DECIMAL",
"INT64",
"INT",
"SMALLINT",
"BIGINT",
"TINYINT",
"BYTEINT",
)
FLOAT = ("FLOAT",)
DATE = ("DATE", "DATETIME", "TIME", "TIMESTAMP")
FLOAT = "FLOAT"
DATE = "DATE"
DATETIME = "DATETIME"
TIME = "TIME"
TIMESTAMP = "TIMESTAMP"
UNDETERMINED = "None"
@attr.define(eq=True)
class Column:
"""Representation of a column in a query, with relevant details for shredder mitigation."""
def __init__(self, name, data_type, column_type, status):
"""Initialize class with required attributes."""
self.name = name
self.data_type = data_type
self.column_type = column_type
self.status = status
name: str
data_type: DataTypeGroup = attr.field(default=DataTypeGroup.UNDETERMINED)
column_type: ColumnType = attr.field(default=ColumnType.UNDETERMINED)
status: ColumnStatus = attr.field(default=ColumnStatus.UNDETERMINED)
def __eq__(self, other):
"""Return attributes only if the referenced object is of type Column."""
if not isinstance(other, Column):
return NotImplemented
return (
self.name == other.name
and self.data_type == other.data_type
and self.column_type == other.column_type
and self.status == other.status
"""Validate the type of the attributes."""
@data_type.validator
def validate_data_type(self, attribute, value):
"""Check that the type of data_type is as expected."""
if not isinstance(value, DataTypeGroup):
raise ValueError(f"Invalid {value} with type: {type(value)}.")
@column_type.validator
def validate_column_type(self, attribute, value):
"""Check that the type of parameter column_type is as expected."""
if not isinstance(value, ColumnType):
raise ValueError(f"Invalid data type for: {value}.")
@status.validator
def validate_status(self, attribute, value):
"""Check that the type of parameter column_status is as expected."""
if not isinstance(value, ColumnStatus):
raise ValueError(f"Invalid data type for: {value}.")
@attr.define(eq=True)
class Subset:
"""Representation of a subset/CTEs in the query and the actions related to this subset."""
client: bigquery.Client
destination_table: str = attr.field(default="")
query_cte: str = attr.field(default="")
dataset: str = attr.field(default=TEMP_DATASET)
project_id: str = attr.field(default=DEFAULT_PROJECT_ID)
expiration_days: Optional[float] = attr.field(default=None)
@property
def expiration_ms(self) -> Optional[float]:
"""Convert partition expiration from days to milliseconds."""
if self.expiration_days is None:
return None
return int(self.expiration_days * 86_400_000)
@property
def version(self):
"""Return the version of the destination table."""
match = re.search(r"v(\d+)$", self.destination_table)
try:
version = int(match.group(1))
if not isinstance(version, int):
raise click.ClickException(
f"{self.destination_table} must end with a positive integer."
)
return version
except (AttributeError, TypeError):
raise click.ClickException(
f"Invalid or missing table version in {self.destination_table}."
)
@property
def full_table_id(self):
"""Return the full id of the destination table."""
return f"{self.project_id}.{self.dataset}.{self.destination_table}"
@property
def query_path(self):
"""Return the full path of the query.sql file associated with the subset."""
sql_path = (
Path("sql")
/ self.project_id
/ self.dataset
/ self.destination_table
/ "query.sql"
)
if not os.path.isfile(sql_path):
click.echo(
click.style(f"Required file not found: {sql_path}.", fg="yellow")
)
return None
return sql_path
def __repr__(self):
"""Return a string representation of the object."""
return f"Column(name={self.name}, data_type={self.data_type}, column_type={self.column_type}, status={self.status})"
@property
def partitioning(self):
"""Return the partition details of the destination table."""
metadata = Metadata.from_file(
Path("sql")
/ self.project_id
/ self.dataset
/ self.destination_table
/ METADATA_FILE
)
if metadata.bigquery and metadata.bigquery.time_partitioning:
partitioning = {
"type": metadata.bigquery.time_partitioning.type.name,
"field": metadata.bigquery.time_partitioning.field,
}
else:
partitioning = {"type": None, "field": None}
return partitioning
def generate_query(
self,
select_list,
from_clause,
where_clause=None,
group_by_clause=None,
order_by_clause=None,
having_clause=None,
):
"""Build query to populate the table."""
if not select_list or not from_clause:
raise click.ClickException(
f"Missing required clause to generate query.\n"
f"Actuals: SELECT: {select_list}, FROM: {self.full_table_id}"
)
query = f"SELECT {', '.join(map(str, select_list))}"
query += f" FROM {from_clause}" if from_clause is not None else ""
query += f" WHERE {where_clause}" if where_clause is not None else ""
query += f" GROUP BY {group_by_clause}" if group_by_clause is not None else ""
query += (
f" HAVING {having_clause}"
if having_clause is not None and group_by_clause is not None
else ""
)
query += f" ORDER BY {order_by_clause}" if order_by_clause is not None else ""
return query
def get_query_path_results(
self,
backfill_date: date = PREVIOUS_DATE,
row_limit: Optional[int] = None,
**kwargs,
) -> list[dict[str, Any]]:
"""Run the query in sql_path & return result or number of rows requested."""
having_clause = None
for key, value in kwargs.items():
if key.lower() == "having_clause":
having_clause = f"{value}"
with open(self.query_path, "r") as file:
sql_text = file.read().strip()
if sql_text.endswith(";"):
sql_text = sql_text[:-1]
if having_clause:
sql_text = f"{sql_text} {having_clause}"
if row_limit:
sql_text = f"{sql_text} LIMIT {row_limit}"
partition_field = self.partitioning["field"]
partition_type = (
"DATE" if self.partitioning["type"] == "DAY" else self.partitioning["type"]
)
parameters = None
if partition_field is not None:
parameters = [
bigquery.ScalarQueryParameter(
partition_field, partition_type, backfill_date
),
]
try:
query_job = self.client.query(
query=sql_text,
job_config=bigquery.QueryJobConfig(
query_parameters=parameters,
use_legacy_sql=False,
dry_run=False,
use_query_cache=False,
),
)
query_results = query_job.result()
except NotFound as e:
raise click.ClickException(
f"Unable to query data for {backfill_date}. Table {self.full_table_id} not found."
) from e
rows = [dict(row) for row in query_results]
return rows
def compare_current_and_previous_version(
self,
date_partition_parameter,
):
"""Generate and run a data check to compare existing and backfilled data."""
return NotImplemented
def get_bigquery_type(value) -> DataTypeGroup:
"""Return the datatype of a value, grouping similar types."""
date_formats = [
("%Y-%m-%d", date),
("%Y-%m-%d %H:%M:%S", datetime),
("%Y-%m-%dT%H:%M:%S", datetime),
("%Y-%m-%dT%H:%M:%SZ", datetime),
("%Y-%m-%d %H:%M:%S UTC", datetime),
("%H:%M:%S", time),
]
for format, dtype in date_formats:
try:
if dtype == time:
parsed_to_time = datetime.strptime(value, format).time()
if isinstance(parsed_to_time, time):
return DataTypeGroup.DATE
else:
parsed_to_date = datetime.strptime(value, format)
if isinstance(parsed_to_date, dtype):
return DataTypeGroup.DATE
except (ValueError, TypeError):
continue
"""Find the datatype of a value, grouping similar types."""
if isinstance(value, dt):
return DataTypeGroup.DATETIME
try:
if isinstance(dt.strptime(value, "%H:%M:%S").time(), time):
return DataTypeGroup.TIME
except (ValueError, TypeError, AttributeError):
pass
try:
value_parsed = parser.isoparse(value.replace(" UTC", "Z").replace(" ", "T"))
if (
isinstance(value_parsed, dt)
and value_parsed.time() == time(0, 0)
and isinstance(dt.strptime(value, "%Y-%m-%d"), date)
):
return DataTypeGroup.DATE
if isinstance(value_parsed, dt) and value_parsed.tzinfo is None:
return DataTypeGroup.DATETIME
if isinstance(value_parsed, dt) and value_parsed.tzinfo is not None:
return DataTypeGroup.TIMESTAMP
except (ValueError, TypeError, AttributeError):
pass
if isinstance(value, time):
return DataTypeGroup.DATE
return DataTypeGroup.TIME
if isinstance(value, date):
return DataTypeGroup.DATE
elif isinstance(value, bool):
if isinstance(value, bool):
return DataTypeGroup.BOOLEAN
if isinstance(value, int):
return DataTypeGroup.NUMERIC
elif isinstance(value, float):
return DataTypeGroup.INTEGER
if isinstance(value, float):
return DataTypeGroup.FLOAT
elif isinstance(value, str) or isinstance(value, bytes):
if isinstance(value, (str, bytes)):
return DataTypeGroup.STRING
elif isinstance(value, NoneType):
if isinstance(value, NoneType):
return DataTypeGroup.UNDETERMINED
else:
raise ValueError(f"Unsupported data type: {type(value)}")
raise ValueError(f"Unsupported data type: {type(value)}")
def classify_columns(
new_row: dict, existing_dimension_columns: list, new_dimension_columns: list
) -> tuple[list[Column], list[Column], list[Column], list[Column], list[Column]]:
"""Compare the new row with the existing columns and return the list of common, added and removed columns by type."""
"""Compare new row with existing columns & return common, added & removed columns."""
common_dimensions = []
added_dimensions = []
removed_dimensions = []
@ -128,8 +308,10 @@ def classify_columns(
if not new_row or not existing_dimension_columns or not new_dimension_columns:
raise click.ClickException(
f"Missing required parameters. Received: new_row= {new_row}\n"
f"existing_dimension_columns= {existing_dimension_columns},\nnew_dimension_columns= {new_dimension_columns}."
f"\n\nMissing one or more required parameters. Received:"
f"\nnew_row= {new_row}"
f"\nexisting_dimension_columns= {existing_dimension_columns},"
f"\nnew_dimension_columns= {new_dimension_columns}."
)
missing_dimensions = [
@ -137,7 +319,8 @@ def classify_columns(
]
if not len(missing_dimensions) == 0:
raise click.ClickException(
f"Inconsistent parameters. Columns in new dimensions not found in new row: {missing_dimensions}"
f"Existing dimensions don't match columns retrieved by query."
f" Missing {missing_dimensions}."
)
for key in existing_dimension_columns:
@ -175,7 +358,9 @@ def classify_columns(
key not in existing_dimension_columns
and key not in new_dimension_columns
and (
value_type is DataTypeGroup.NUMERIC or value_type is DataTypeGroup.FLOAT
value_type is DataTypeGroup.INTEGER
or value_type is DataTypeGroup.FLOAT
or value_type is DataTypeGroup.NUMERIC
)
):
# Columns that are not in the previous or new list of grouping columns are metrics.
@ -212,3 +397,311 @@ def classify_columns(
metrics_sorted,
undefined_sorted,
)
def generate_query_with_shredder_mitigation(
client, project_id, dataset, destination_table, backfill_date=PREVIOUS_DATE
) -> Tuple[Path, str]:
"""Generate a query to backfill with shredder mitigation."""
query_with_mitigation_path = Path("sql") / project_id
# Find query files and grouping of previous and new queries.
new = Subset(client, destination_table, "new_version", dataset, project_id, None)
if new.version < 2:
raise click.ClickException(
f"The new version of the table is expected >= 2. Actual is {new.version}."
)
destination_table_previous_version = (
f"{destination_table[:-len(str(new.version))]}{new.version-1}"
)
previous = Subset(
client,
destination_table_previous_version,
"previous",
dataset,
project_id,
None,
)
new_group_by = extract_last_group_by_from_query(sql_path=new.query_path)
previous_group_by = extract_last_group_by_from_query(sql_path=previous.query_path)
# Check that previous query exists and GROUP BYs are valid in both queries.
integers_in_group_by = False
for e in previous_group_by + new_group_by:
try:
int(e)
integers_in_group_by = True
except ValueError:
continue
if (
"ALL" in previous_group_by
or "ALL" in new_group_by
or not all(isinstance(e, str) for e in previous_group_by)
or not all(isinstance(e, str) for e in new_group_by)
or not previous_group_by
or not new_group_by
or integers_in_group_by
):
raise click.ClickException(
"GROUP BY must use an explicit list of columns. "
"Avoid expressions like `GROUP BY ALL` or `GROUP BY 1, 2, 3`."
)
# Identify columns common to both queries and new columns. This excludes removed columns.
sample_rows = new.get_query_path_results(
backfill_date=backfill_date,
row_limit=1,
having_clause=f"HAVING {' IS NOT NULL AND '.join(new_group_by)} IS NOT NULL",
)
if not sample_rows:
sample_rows = new.get_query_path_results(
backfill_date=backfill_date, row_limit=1
)
try:
new_table_row = sample_rows[0]
(
common_dimensions,
added_dimensions,
removed_dimensions,
metrics,
undetermined_columns,
) = classify_columns(new_table_row, previous_group_by, new_group_by)
except TypeError as e:
raise click.ClickException(
f"Table {destination_table} did not return any rows for {backfill_date}.\n{e}"
)
if not common_dimensions or not added_dimensions or not metrics:
raise click.ClickException(
"The process requires that previous & query have at least one dimension in common,"
" one dimension added and one metric."
)
# Get the new query.
with open(new.query_path, "r") as file:
new_query = file.read().strip()
# Aggregate previous data and new query results using common dimensions.
new_agg = Subset(
client, destination_table, "new_agg", TEMP_DATASET, project_id, None
)
previous_agg = Subset(
client,
destination_table_previous_version,
"previous_agg",
TEMP_DATASET,
project_id,
None,
)
common_select = (
[previous.partitioning["field"]]
+ [
f"COALESCE({dim.name}, '??') AS {dim.name}"
for dim in common_dimensions
if (
dim.name != previous.partitioning["field"]
and dim.data_type == DataTypeGroup.STRING
)
]
+ [
f"COALESCE({dim.name}, -999) AS {dim.name}"
for dim in common_dimensions
if (
dim.name != new.partitioning["field"]
and dim.data_type in (DataTypeGroup.INTEGER, DataTypeGroup.FLOAT)
)
]
+ [
dim.name
for dim in common_dimensions
if (
dim.name != new.partitioning["field"]
and dim.data_type in (DataTypeGroup.BOOLEAN, DataTypeGroup.DATE)
)
]
+ [f"SUM({metric.name}) AS {metric.name}" for metric in metrics]
)
new_agg_query = new_agg.generate_query(
select_list=common_select,
from_clause=f"{new.query_cte}",
group_by_clause="ALL",
)
previous_agg_query = previous_agg.generate_query(
select_list=common_select,
from_clause=f"`{previous.full_table_id}`",
where_clause=f"{previous.partitioning['field']} = @{previous.partitioning['field']}",
group_by_clause="ALL",
)
# Calculate shredder impact.
shredded = Subset(
client, destination_table, "shredded", TEMP_DATASET, project_id, None
)
# Set values to NULL for the supported types.
shredded_select = (
[f"{previous_agg.query_cte}.{new.partitioning['field']}"]
+ [
f"{previous_agg.query_cte}.{dim.name}"
for dim in common_dimensions
if (dim.name != new.partitioning["field"])
]
+ [
f"CAST(NULL AS {DataTypeGroup.STRING.name}) AS {dim.name}"
for dim in added_dimensions
if (
dim.name != new.partitioning["field"]
and dim.data_type == DataTypeGroup.STRING
)
]
+ [
f"CAST(NULL AS {DataTypeGroup.BOOLEAN.name}) AS {dim.name}"
for dim in added_dimensions
if (
dim.name != new.partitioning["field"]
and dim.data_type == DataTypeGroup.BOOLEAN
)
]
# This doesn't convert data or dtypes, it's only used to cast NULLs for UNION queries.
+ [
f"CAST(NULL AS {DataTypeGroup.DATE.name}) AS {dim.name}"
for dim in added_dimensions
if (
dim.name != new.partitioning["field"]
and dim.data_type == DataTypeGroup.DATE
)
]
+ [
f"CAST(NULL AS {DataTypeGroup.INTEGER.name}) AS {dim.name}"
for dim in added_dimensions
if (
dim.name != new.partitioning["field"]
and dim.data_type == DataTypeGroup.INTEGER
)
]
+ [
f"CAST(NULL AS {DataTypeGroup.FLOAT.name}) AS {dim.name}"
for dim in added_dimensions
if (
dim.name != new.partitioning["field"]
and dim.data_type == DataTypeGroup.FLOAT
)
]
+ [
f"NULL AS {dim.name}"
for dim in added_dimensions
if (
dim.name != new.partitioning["field"]
and dim.data_type
not in (
DataTypeGroup.STRING,
DataTypeGroup.BOOLEAN,
DataTypeGroup.INTEGER,
DataTypeGroup.FLOAT,
)
)
]
+ [
f"{previous_agg.query_cte}.{metric.name} - IFNULL({new_agg.query_cte}.{metric.name}, 0)"
f" AS {metric.name}"
for metric in metrics
if metric.data_type != DataTypeGroup.FLOAT
]
+ [
f"ROUND({previous_agg.query_cte}.{metric.name}, 3) - "
f"ROUND(IFNULL({new_agg.query_cte}.{metric.name}, 0), 3) AS {metric.name}"
for metric in metrics
if metric.data_type == DataTypeGroup.FLOAT
]
)
shredded_join = " AND ".join(
[
f"{previous_agg.query_cte}.{previous.partitioning['field']} ="
f" {new_agg.query_cte}.{new.partitioning['field']}"
]
+ [
f"{previous_agg.query_cte}.{dim.name} = {new_agg.query_cte}.{dim.name}"
for dim in common_dimensions
if dim.name != previous.partitioning["field"]
and dim.data_type not in (DataTypeGroup.BOOLEAN, DataTypeGroup.DATE)
]
+ [
f"({previous_agg.query_cte}.{dim.name} = {new_agg.query_cte}.{dim.name} OR"
f" ({previous_agg.query_cte}.{dim.name} IS NULL"
f" AND {new_agg.query_cte}.{dim.name} IS NULL))" # Compare null values.
for dim in common_dimensions
if dim.name != previous.partitioning["field"]
and dim.data_type in (DataTypeGroup.BOOLEAN, DataTypeGroup.DATE)
]
)
shredded_query = shredded.generate_query(
select_list=shredded_select,
from_clause=f"{previous_agg.query_cte} LEFT JOIN {new_agg.query_cte} ON {shredded_join} ",
where_clause=" OR ".join(
[
f"{previous_agg.query_cte}.{metric.name} > IFNULL({new_agg.query_cte}.{metric.name}, 0)"
for metric in metrics
]
),
)
combined_list = (
[dim.name for dim in common_dimensions]
+ [dim.name for dim in added_dimensions]
+ [metric.name for metric in metrics]
)
final_select = f"{', '.join(combined_list)}"
click.echo(
click.style(
f"""Generating query with shredder mitigation and the following columns:
Dimensions in both versions:
{[f"{dim.name}:{dim.data_type.name}" for dim in common_dimensions]},
Dimensions added:
{[f"{dim.name}:{dim.data_type.name}" for dim in added_dimensions]}
Metrics:
{[f"{dim.name}:{dim.data_type.name}" for dim in metrics]},
Columns that could not be classified:
{[f"{dim.name}:{dim.data_type.name}" for dim in undetermined_columns]}.""",
fg="yellow",
)
)
# Generate query from template.
env = Environment(loader=FileSystemLoader(str(THIS_PATH)))
query_with_mitigation_template = env.get_template(
f"{QUERY_WITH_MITIGATION_NAME}_template.sql"
)
query_with_mitigation_sql = reformat(
query_with_mitigation_template.render(
new_version_cte=new.query_cte,
new_version=new_query,
new_agg_cte=new_agg.query_cte,
new_agg=new_agg_query,
previous_agg_cte=previous_agg.query_cte,
previous_agg=previous_agg_query,
shredded_cte=shredded.query_cte,
shredded=shredded_query,
final_select=final_select,
)
)
write_sql(
output_dir=query_with_mitigation_path,
full_table_id=new.full_table_id,
basename=f"{QUERY_WITH_MITIGATION_NAME}.sql",
sql=query_with_mitigation_sql,
skip_existing=False,
)
# return Path("sql")
return (
Path("sql")
/ new.project_id
/ new.dataset
/ new.destination_table
/ f"{QUERY_WITH_MITIGATION_NAME}.sql",
query_with_mitigation_sql,
)
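
For orientation, a minimal sketch of calling the new entry point directly; the dataset and table names below are made-up placeholders, and the real flow goes through the backfill CLI shown in the cli/backfill.py diff that follows:

from google.cloud import bigquery

from bigquery_etl.backfill.shredder_mitigation import (
    generate_query_with_shredder_mitigation,
)

# Placeholder identifiers: the destination table must end in a version
# suffix (e.g. _v2), and the previous version (_v1) must exist with its
# own query.sql and metadata.yaml under sql/<project>/<dataset>/.
client = bigquery.Client(project="moz-fx-data-shared-prod")
query_path, query_sql = generate_query_with_shredder_mitigation(
    client=client,
    project_id="moz-fx-data-shared-prod",
    dataset="example_derived",
    destination_table="example_aggregates_v2",
    backfill_date="2024-09-01",
)
# query_path points at .../example_aggregates_v2/query_with_shredder_mitigation.sql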

View file

@ -22,6 +22,7 @@ from ..backfill.parse import (
Backfill,
BackfillStatus,
)
from ..backfill.shredder_mitigation import generate_query_with_shredder_mitigation
from ..backfill.utils import (
get_backfill_backup_table_name,
get_backfill_file_from_qualified_table_name,
@ -508,10 +509,27 @@ def _initiate_backfill(
custom_query = None
if entry.shredder_mitigation is True:
custom_query = "query_with_shredder_mitigation.sql" # TODO: Replace with "= generate_query_with_shredder_mitigation()"
logging_str += " This backfill uses a query with shredder mitigation."
click.echo(
click.style(
f"Generating query with shredder mitigation for {dataset}.{table}...",
fg="blue",
)
)
query, _ = generate_query_with_shredder_mitigation(
client=bigquery.Client(project=project),
project_id=project,
dataset=dataset,
destination_table=table,
backfill_date=entry.start_date.isoformat(),
)
custom_query = Path(query)
click.echo(
click.style(
f"Starting backfill with custom query: '{custom_query}'.", fg="blue"
)
)
elif entry.custom_query:
custom_query = entry.custom_query
custom_query = Path(entry.custom_query)
# backfill table
# in the long-run we should remove the query backfill command and require a backfill entry for all backfills

View file

@ -675,9 +675,7 @@ def backfill(
sys.exit(1)
if custom_query:
query_files = paths_matching_name_pattern(
name, sql_dir, project_id, [custom_query]
)
query_files = paths_matching_name_pattern(custom_query, sql_dir, project_id)
else:
query_files = paths_matching_name_pattern(name, sql_dir, project_id)

View file

@ -251,47 +251,3 @@ def temp_dataset_option(
help="Dataset where intermediate query results will be temporarily stored, "
"formatted as PROJECT_ID.DATASET_ID",
)
def extract_last_group_by_from_query(
sql_path: Optional[Path] = None, sql_text: Optional[str] = None
):
"""Return the list of columns in the latest group by of a query."""
if not sql_path and not sql_text:
click.ClickException(
"Please provide an sql file or sql text to extract the group by."
)
if sql_path:
try:
query_text = sql_path.read_text()
except (FileNotFoundError, OSError):
click.ClickException(f'Failed to read query from: "{sql_path}."')
else:
query_text = str(sql_text)
group_by_list = []
# Remove single and multi-line comments (/* */), trailing semicolon if present and normalize whitespace.
query_text = re.sub(r"/\*.*?\*/", "", query_text, flags=re.DOTALL)
query_text = re.sub(r"--[^\n]*\n", "\n", query_text)
query_text = re.sub(r"\s+", " ", query_text).strip()
if query_text.endswith(";"):
query_text = query_text[:-1].strip()
last_group_by_original = re.search(
r"(?i){}(?!.*{})".format(re.escape("GROUP BY"), re.escape("GROUP BY")),
query_text,
re.DOTALL,
)
if last_group_by_original:
group_by = query_text[last_group_by_original.end() :].lstrip()
# Remove parenthesis, closing parenthesis, LIMIT, ORDER BY and text after those. Remove also opening parenthesis.
group_by = (
re.sub(r"(?i)[\n\)].*|LIMIT.*|ORDER BY.*", "", group_by)
.replace("(", "")
.strip()
)
if group_by:
group_by_list = group_by.replace(" ", "").replace("\n", "").split(",")
return group_by_list

View file

@ -9,9 +9,10 @@ import string
import tempfile
import warnings
from pathlib import Path
from typing import List, Set, Tuple
from typing import List, Optional, Set, Tuple
from uuid import uuid4
import click
import sqlglot
from google.cloud import bigquery
from jinja2 import Environment, FileSystemLoader
@ -287,6 +288,50 @@ def qualify_table_references_in_file(path: Path) -> str:
return updated_query
def extract_last_group_by_from_query(
sql_path: Optional[Path] = None, sql_text: Optional[str] = None
):
"""Return the list of columns in the latest group by of a query."""
if not sql_path and not sql_text:
raise click.ClickException(
"Missing an sql file or sql text to extract the group by."
)
if sql_path:
try:
query_text = sql_path.read_text()
except (FileNotFoundError, OSError):
raise click.ClickException(f'Failed to read query from: "{sql_path}."')
else:
query_text = str(sql_text)
group_by_list = []
# Remove single and multi-line comments (/* */), trailing semicolon if present and normalize whitespace.
query_text = re.sub(r"/\*.*?\*/", "", query_text, flags=re.DOTALL)
query_text = re.sub(r"--[^\n]*\n", "\n", query_text)
query_text = re.sub(r"\s+", " ", query_text).strip()
if query_text.endswith(";"):
query_text = query_text[:-1].strip()
last_group_by_original = re.search(
r"(?i){}(?!.*{})".format(re.escape("GROUP BY"), re.escape("GROUP BY")),
query_text,
re.DOTALL,
)
if last_group_by_original:
group_by = query_text[last_group_by_original.end() :].lstrip()
# Remove parenthesis, closing parenthesis, LIMIT, ORDER BY and text after those. Remove also opening parenthesis.
group_by = (
re.sub(r"(?i)[\n\)].*|LIMIT.*|ORDER BY.*", "", group_by)
.replace("(", "")
.strip()
)
if group_by:
group_by_list = group_by.replace(" ", "").replace("\n", "").split(",")
return group_by_list
class TempDatasetReference(bigquery.DatasetReference):
"""Extend DatasetReference to simplify generating temporary tables."""

View file

@ -1,14 +1,23 @@
from datetime import datetime, time
import os
from datetime import date, datetime, time
from pathlib import Path
from unittest.mock import call, patch
import click
import pytest
from click.exceptions import ClickException
from click.testing import CliRunner
from bigquery_etl.backfill.shredder_mitigation import (
PREVIOUS_DATE,
QUERY_WITH_MITIGATION_NAME,
Column,
ColumnStatus,
ColumnType,
DataTypeGroup,
Subset,
classify_columns,
generate_query_with_shredder_mitigation,
get_bigquery_type,
)
@ -54,7 +63,7 @@ class TestClassifyColumns(object):
),
Column(
name="first_seen_year",
data_type=DataTypeGroup.NUMERIC,
data_type=DataTypeGroup.INTEGER,
column_type=ColumnType.DIMENSION,
status=ColumnStatus.COMMON,
),
@ -93,7 +102,7 @@ class TestClassifyColumns(object):
),
Column(
name="metric_numeric",
data_type=DataTypeGroup.NUMERIC,
data_type=DataTypeGroup.INTEGER,
column_type=ColumnType.METRIC,
status=ColumnStatus.COMMON,
),
@ -135,7 +144,7 @@ class TestClassifyColumns(object):
),
Column(
name="first_seen_year",
data_type=DataTypeGroup.NUMERIC,
data_type=DataTypeGroup.INTEGER,
column_type=ColumnType.DIMENSION,
status=ColumnStatus.COMMON,
),
@ -167,7 +176,7 @@ class TestClassifyColumns(object):
),
Column(
name="metric_numeric",
data_type=DataTypeGroup.NUMERIC,
data_type=DataTypeGroup.INTEGER,
column_type=ColumnType.METRIC,
status=ColumnStatus.COMMON,
),
@ -240,7 +249,7 @@ class TestClassifyColumns(object):
),
Column(
name="metric_numeric",
data_type=DataTypeGroup.NUMERIC,
data_type=DataTypeGroup.INTEGER,
column_type=ColumnType.METRIC,
status=ColumnStatus.COMMON,
),
@ -384,7 +393,7 @@ class TestClassifyColumns(object):
expected_metrics = [
Column(
name="metric_bigint",
data_type=DataTypeGroup.NUMERIC,
data_type=DataTypeGroup.INTEGER,
column_type=ColumnType.METRIC,
status=ColumnStatus.COMMON,
),
@ -396,7 +405,7 @@ class TestClassifyColumns(object):
),
Column(
name="metric_int",
data_type=DataTypeGroup.NUMERIC,
data_type=DataTypeGroup.INTEGER,
column_type=ColumnType.METRIC,
status=ColumnStatus.COMMON,
),
@ -411,7 +420,7 @@ class TestClassifyColumns(object):
assert metrics == expected_metrics
assert undefined == []
def test_matching_new_row_and_new_columns(self):
def test_not_matching_new_row_and_new_columns(self):
new_row = {
"submission_date": "2024-01-01",
"channel": None,
@ -420,53 +429,54 @@ class TestClassifyColumns(object):
}
existing_columns = ["submission_date", "channel"]
new_columns = ["submission_date", "channel", "os", "is_default_browser"]
expected_exception_text = "Inconsistent parameters. Columns in new dimensions not found in new row: ['is_default_browser']"
expected_exception_text = "Existing dimensions don't match columns retrieved by query. Missing ['is_default_browser']."
with pytest.raises(ClickException) as e:
classify_columns(new_row, existing_columns, new_columns)
assert (str(e.value)) == expected_exception_text
assert (str(e.value.message)) == expected_exception_text
def test_missing_parameters(self):
new_row = {}
expected_exception_text = (
f"Missing required parameters. Received: new_row= {new_row}\n"
f"\n\nMissing one or more required parameters. Received:\nnew_row= {new_row}\n"
f"existing_dimension_columns= [],\nnew_dimension_columns= []."
)
with pytest.raises(ClickException) as e:
classify_columns(new_row, [], [])
assert (str(e.value)) == expected_exception_text
assert (str(e.value.message)) == expected_exception_text
new_row = {"column_1": "2024-01-01", "column_2": "Windows"}
new_columns = {"column_2"}
expected_exception_text = (
f"Missing required parameters. Received: new_row= {new_row}\n"
f"existing_dimension_columns= [],\nnew_dimension_columns= []."
f"\n\nMissing one or more required parameters. Received:\nnew_row= {new_row}\n"
f"existing_dimension_columns= [],\nnew_dimension_columns= {new_columns}."
)
with pytest.raises(ClickException) as e:
classify_columns(new_row, [], [])
assert (str(e.value)) == expected_exception_text
classify_columns(new_row, [], new_columns)
assert (str(e.value.message)) == expected_exception_text
new_row = {"column_1": "2024-01-01", "column_2": "Windows"}
existing_columns = ["column_1"]
expected_exception_text = (
f"Missing required parameters. Received: new_row= {new_row}\n"
f"\n\nMissing one or more required parameters. Received:\nnew_row= {new_row}\n"
f"existing_dimension_columns= {existing_columns},\nnew_dimension_columns= []."
)
with pytest.raises(ClickException) as e:
classify_columns(new_row, existing_columns, [])
assert (str(e.value)) == expected_exception_text
assert (str(e.value.message)) == expected_exception_text
class TestGetBigqueryType(object):
def test_numeric_group(self):
assert get_bigquery_type(3) == DataTypeGroup.NUMERIC
assert get_bigquery_type(2024) == DataTypeGroup.NUMERIC
assert get_bigquery_type(9223372036854775807) == DataTypeGroup.NUMERIC
assert get_bigquery_type(123456) == DataTypeGroup.NUMERIC
assert get_bigquery_type(-123456) == DataTypeGroup.NUMERIC
assert get_bigquery_type(3) == DataTypeGroup.INTEGER
assert get_bigquery_type(2024) == DataTypeGroup.INTEGER
assert get_bigquery_type(9223372036854775807) == DataTypeGroup.INTEGER
assert get_bigquery_type(123456) == DataTypeGroup.INTEGER
assert get_bigquery_type(-123456) == DataTypeGroup.INTEGER
assert get_bigquery_type(789.01) == DataTypeGroup.FLOAT
assert get_bigquery_type(1.00000000000000000000000456) == DataTypeGroup.FLOAT
assert get_bigquery_type(999999999999999999999.999999999) == DataTypeGroup.FLOAT
assert get_bigquery_type(-1.23456) == DataTypeGroup.FLOAT
assert get_bigquery_type(100000000000000000000.123456789) == DataTypeGroup.FLOAT
def test_boolean_group(self):
assert get_bigquery_type(False) == DataTypeGroup.BOOLEAN
@ -474,14 +484,630 @@ class TestGetBigqueryType(object):
def test_date_type(self):
assert get_bigquery_type("2024-01-01") == DataTypeGroup.DATE
assert get_bigquery_type("2024-01-01T10:00:00") == DataTypeGroup.DATE
assert get_bigquery_type("2024-01-01T10:00:00Z") == DataTypeGroup.DATE
assert get_bigquery_type("2024-01-01T10:00:00") == DataTypeGroup.DATE
assert get_bigquery_type("2024-08-01 12:34:56 UTC") == DataTypeGroup.DATE
assert get_bigquery_type("12:34:56") == DataTypeGroup.DATE
assert get_bigquery_type(time(12, 34, 56)) == DataTypeGroup.DATE
assert get_bigquery_type(datetime(2024, 12, 26)) == DataTypeGroup.DATE
assert get_bigquery_type("2024-01-01T10:00:00") == DataTypeGroup.DATETIME
assert get_bigquery_type("2024-01-01T10:00:00Z") == DataTypeGroup.TIMESTAMP
assert get_bigquery_type("2024-01-01T10:00:00") == DataTypeGroup.DATETIME
assert get_bigquery_type("2024-08-01 12:34:56 UTC") == DataTypeGroup.TIMESTAMP
assert get_bigquery_type("2024-09-02 14:30:45") == DataTypeGroup.DATETIME
assert get_bigquery_type("12:34:56") == DataTypeGroup.TIME
assert get_bigquery_type(time(12, 34, 56)) == DataTypeGroup.TIME
assert get_bigquery_type(datetime(2024, 12, 26)) == DataTypeGroup.DATETIME
assert get_bigquery_type(date(2024, 12, 26)) == DataTypeGroup.DATE
def test_other_types(self):
assert get_bigquery_type("2024") == DataTypeGroup.STRING
assert get_bigquery_type(None) == DataTypeGroup.UNDETERMINED
class TestSubset(object):
project_id = "moz-fx-data-shared-prod"
dataset = "test"
destination_table = "test_query_v2"
destination_table_previous = "test_query_v1"
path = Path("sql") / project_id / dataset / destination_table
path_previous = Path("sql") / project_id / dataset / destination_table_previous
@pytest.fixture
def runner(self):
return CliRunner()
@patch("google.cloud.bigquery.Client")
def test_version(self, mock_client):
test_tables_correct = [
("test_v1", 1),
("test_v10", 10),
("test_v0", 0),
]
for table, expected in test_tables_correct:
test_subset = Subset(
mock_client, table, None, self.dataset, self.project_id, None
)
assert test_subset.version == expected
test_tables_incorrect = [
("test_v-19", 1),
("test_v", 1),
("test_3", None),
("test_10", None),
("test_3", 1),
]
for table, expected in test_tables_incorrect:
test_subset = Subset(
mock_client, table, None, self.dataset, self.project_id, None
)
with pytest.raises(click.ClickException) as e:
_ = test_subset.version
assert e.type == click.ClickException
assert (
e.value.message
== f"Invalid or missing table version in {test_subset.destination_table}."
)
@patch("google.cloud.bigquery.Client")
def test_partitioning(self, mock_client, runner):
"""Test that partitioning type and value associated to a subset are returned as expected."""
test_subset = Subset(
mock_client,
self.destination_table,
None,
self.dataset,
self.project_id,
None,
)
with runner.isolated_filesystem():
os.makedirs(Path(self.path), exist_ok=True)
with open(Path(self.path) / "metadata.yaml", "w") as f:
f.write(
"bigquery:\n time_partitioning:\n type: day\n field: submission_date\n require_partition_filter: true"
)
assert test_subset.partitioning == {
"field": "submission_date",
"type": "DAY",
}
with open(Path(self.path) / "metadata.yaml", "w") as f:
f.write(
"bigquery:\n time_partitioning:\n type: day\n field: first_seen_date\n require_partition_filter: true"
)
assert test_subset.partitioning == {
"field": "first_seen_date",
"type": "DAY",
}
with open(Path(self.path) / "metadata.yaml", "w") as f:
f.write(
"friendly_name: ABCD\ndescription: ABCD\nlabels:\n incremental: true"
)
assert test_subset.partitioning == {"field": None, "type": None}
@patch("google.cloud.bigquery.Client")
def test_generate_query(self, mock_client):
"""Test cases: aggregate, different columns, different metrics, missing metrics, added columns / metrics"""
test_subset = Subset(
mock_client,
self.destination_table,
None,
self.dataset,
self.project_id,
None,
)
test_subset_query = test_subset.generate_query(
select_list=["column_1"],
from_clause=f"{self.destination_table_previous}",
group_by_clause="ALL",
)
assert test_subset_query == (
f"SELECT column_1 FROM {self.destination_table_previous}" f" GROUP BY ALL"
)
test_subset_query = test_subset.generate_query(
select_list=[1, 2, 3],
from_clause=f"{self.destination_table_previous}",
order_by_clause="1, 2, 3",
)
assert test_subset_query == (
f"SELECT 1, 2, 3 FROM {self.destination_table_previous}"
f" ORDER BY 1, 2, 3"
)
test_subset_query = test_subset.generate_query(
select_list=["column_1", 2, 3],
from_clause=f"{self.destination_table_previous}",
group_by_clause="1, 2, 3",
)
assert test_subset_query == (
f"SELECT column_1, 2, 3 FROM {self.destination_table_previous}"
f" GROUP BY 1, 2, 3"
)
test_subset_query = test_subset.generate_query(
select_list=["column_1"], from_clause=f"{self.destination_table_previous}"
)
assert (
test_subset_query
== f"SELECT column_1 FROM {self.destination_table_previous}"
)
test_subset_query = test_subset.generate_query(
select_list=["column_1"],
from_clause=f"{self.destination_table_previous}",
where_clause="column_1 IS NOT NULL",
group_by_clause="1",
order_by_clause="1",
)
assert test_subset_query == (
f"SELECT column_1 FROM {self.destination_table_previous}"
f" WHERE column_1 IS NOT NULL GROUP BY 1 ORDER BY 1"
)
test_subset_query = test_subset.generate_query(
select_list=["column_1"],
from_clause=f"{self.destination_table_previous}",
group_by_clause="1",
order_by_clause="1",
)
assert test_subset_query == (
f"SELECT column_1 FROM {self.destination_table_previous}"
f" GROUP BY 1 ORDER BY 1"
)
test_subset_query = test_subset.generate_query(
select_list=["column_1"],
from_clause=f"{self.destination_table_previous}",
having_clause="column_1 > 1",
)
assert (
test_subset_query
== f"SELECT column_1 FROM {self.destination_table_previous}"
)
test_subset_query = test_subset.generate_query(
select_list=["column_1"],
from_clause=f"{self.destination_table_previous}",
group_by_clause="1",
having_clause="column_1 > 1",
)
assert test_subset_query == (
f"SELECT column_1 FROM {self.destination_table_previous}"
f" GROUP BY 1 HAVING column_1 > 1"
)
with pytest.raises(ClickException) as e:
test_subset.generate_query(
select_list=[],
from_clause=f"{self.destination_table_previous}",
group_by_clause="1",
having_clause="column_1 > 1",
)
assert str(e.value.message) == (
f"Missing required clause to generate query.\n"
f"Actuals: SELECT: [], FROM: {test_subset.full_table_id}"
)
@patch("google.cloud.bigquery.Client")
def test_get_query_path_results(self, mock_client, runner):
"""Test expected results for a mocked BigQuery call."""
test_subset = Subset(
mock_client,
self.destination_table,
None,
self.dataset,
self.project_id,
None,
)
expected = [{"column_1": "1234"}]
with runner.isolated_filesystem():
os.makedirs(self.path, exist_ok=True)
with open(Path(self.path) / "query.sql", "w") as f:
f.write("SELECT column_1 WHERE submission_date = @submission_date")
with pytest.raises(FileNotFoundError) as e:
test_subset.get_query_path_results(None)
assert "metadata.yaml" in str(e)
with open(Path(self.path) / "metadata.yaml", "w") as f:
f.write(
"bigquery:\n time_partitioning:\n type: day\n field: submission_date"
)
mock_query = mock_client.query
mock_query.return_value.result.return_value = iter(expected)
result = test_subset.get_query_path_results(None)
assert result == expected
def test_generate_check_with_previous_version(self):
assert True
class TestGenerateQueryWithShredderMitigation(object):
"""Test the function that generates the query for the backfill."""
project_id = "moz-fx-data-shared-prod"
dataset = "test"
destination_table = "test_query_v2"
destination_table_previous = "test_query_v1"
path = Path("sql") / project_id / dataset / destination_table
path_previous = Path("sql") / project_id / dataset / destination_table_previous
@pytest.fixture
def runner(self):
return CliRunner()
@patch("google.cloud.bigquery.Client")
@patch("bigquery_etl.backfill.shredder_mitigation.classify_columns")
def test_generate_query_as_expected(
self, mock_classify_columns, mock_client, runner
):
"""Test that query is generated as expected given a set of mock dimensions and metrics."""
expected = (
Path("sql")
/ self.project_id
/ self.dataset
/ self.destination_table
/ f"{QUERY_WITH_MITIGATION_NAME}.sql",
"""-- Query generated using a template for shredder mitigation.
WITH new_version AS (
SELECT
column_1,
column_2,
metric_1
FROM
upstream_1
GROUP BY
column_1,
column_2
),
new_agg AS (
SELECT
submission_date,
COALESCE(column_1, '??') AS column_1,
SUM(metric_1) AS metric_1
FROM
new_version
GROUP BY
ALL
),
previous_agg AS (
SELECT
submission_date,
COALESCE(column_1, '??') AS column_1,
SUM(metric_1) AS metric_1
FROM
`moz-fx-data-shared-prod.test.test_query_v1`
WHERE
submission_date = @submission_date
GROUP BY
ALL
),
shredded AS (
SELECT
previous_agg.submission_date,
previous_agg.column_1,
CAST(NULL AS STRING) AS column_2,
previous_agg.metric_1 - IFNULL(new_agg.metric_1, 0) AS metric_1
FROM
previous_agg
LEFT JOIN
new_agg
ON previous_agg.submission_date = new_agg.submission_date
AND previous_agg.column_1 = new_agg.column_1
WHERE
previous_agg.metric_1 > IFNULL(new_agg.metric_1, 0)
)
SELECT
column_1,
column_2,
metric_1
FROM
new_version
UNION ALL
SELECT
column_1,
column_2,
metric_1
FROM
shredded;""",
)
with runner.isolated_filesystem():
os.makedirs(self.path, exist_ok=True)
os.makedirs(self.path_previous, exist_ok=True)
with open(
Path("sql")
/ self.project_id
/ self.dataset
/ self.destination_table
/ "query.sql",
"w",
) as f:
f.write(
"SELECT column_1, column_2, metric_1 FROM upstream_1 GROUP BY column_1, column_2"
)
with open(
Path("sql")
/ self.project_id
/ self.dataset
/ self.destination_table_previous
/ "query.sql",
"w",
) as f:
f.write("SELECT column_1, metric_1 FROM upstream_1 GROUP BY column_1")
with open(Path(self.path) / "metadata.yaml", "w") as f:
f.write(
"bigquery:\n time_partitioning:\n type: day\n field: submission_date\n require_partition_filter: true"
)
with open(Path(self.path_previous) / "metadata.yaml", "w") as f:
f.write(
"bigquery:\n time_partitioning:\n type: day\n field: submission_date\n require_partition_filter: true"
)
mock_classify_columns.return_value = (
[
Column(
"column_1",
DataTypeGroup.STRING,
ColumnType.DIMENSION,
ColumnStatus.COMMON,
)
],
[
Column(
"column_2",
DataTypeGroup.STRING,
ColumnType.DIMENSION,
ColumnStatus.ADDED,
)
],
[],
[
Column(
"metric_1",
DataTypeGroup.INTEGER,
ColumnType.METRIC,
ColumnStatus.COMMON,
)
],
[],
)
with patch.object(
Subset,
"get_query_path_results",
return_value=[{"column_1": "ABC", "column_2": "DEF", "metric_1": 10.0}],
):
assert os.path.isfile(self.path / "query.sql")
assert os.path.isfile(self.path_previous / "query.sql")
result = generate_query_with_shredder_mitigation(
client=mock_client,
project_id=self.project_id,
dataset=self.dataset,
destination_table=self.destination_table,
backfill_date=PREVIOUS_DATE,
)
assert result[0] == expected[0]
assert result[1] == expected[1].replace(" ", "")
@patch("google.cloud.bigquery.Client")
def test_missing_previous_version(self, mock_client, runner):
"""Test that the process raises an exception when previous query version is missing."""
expected_exc = "Missing an sql file or sql text to extract the group by."
with runner.isolated_filesystem():
path = f"sql/{self.project_id}/{self.dataset}/{self.destination_table}"
os.makedirs(path, exist_ok=True)
with open(Path(path) / "query.sql", "w") as f:
f.write("SELECT column_1, column_2 FROM upstream_1 GROUP BY column_1")
with pytest.raises(ClickException) as e:
generate_query_with_shredder_mitigation(
client=mock_client,
project_id=self.project_id,
dataset=self.dataset,
destination_table=self.destination_table,
backfill_date=PREVIOUS_DATE,
)
assert (str(e.value.message)) == expected_exc
assert (e.type) == ClickException
@patch("google.cloud.bigquery.Client")
def test_invalid_group_by(self, mock_client, runner):
"""Test that the process raises an exception when the GROUP BY is invalid for any query."""
expected_exc = (
"GROUP BY must use an explicit list of columns. "
"Avoid expressions like `GROUP BY ALL` or `GROUP BY 1, 2, 3`."
)
# client = bigquery.Client()
project_id = "moz-fx-data-shared-prod"
dataset = "test"
destination_table = "test_query_v2"
destination_table_previous = "test_query_v1"
# GROUP BY including a number
with runner.isolated_filesystem():
previous_group_by = "column_1, column_2, column_3"
new_group_by = "3, column_4, column_5"
path = f"sql/{project_id}/{dataset}/{destination_table}"
path_previous = f"sql/{project_id}/{dataset}/{destination_table_previous}"
os.makedirs(path, exist_ok=True)
os.makedirs(path_previous, exist_ok=True)
with open(Path(path) / "query.sql", "w") as f:
f.write(
f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {new_group_by}"
)
with open(Path(path_previous) / "query.sql", "w") as f:
f.write(
f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {previous_group_by}"
)
with pytest.raises(ClickException) as e:
generate_query_with_shredder_mitigation(
client=mock_client,
project_id=project_id,
dataset=dataset,
destination_table=destination_table,
backfill_date=PREVIOUS_DATE,
)
assert (str(e.value.message)) == expected_exc
# GROUP BY 1, 2, 3
previous_group_by = "1, 2, 3"
new_group_by = "column_1, column_2, column_3"
with open(Path(path) / "query.sql", "w") as f:
f.write(
f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {new_group_by}"
)
with open(Path(path_previous) / "query.sql", "w") as f:
f.write(
f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {previous_group_by}"
)
with pytest.raises(ClickException) as e:
generate_query_with_shredder_mitigation(
client=mock_client,
project_id=project_id,
dataset=dataset,
destination_table=destination_table,
backfill_date=PREVIOUS_DATE,
)
assert (str(e.value.message)) == expected_exc
# GROUP BY ALL
previous_group_by = "column_1, column_2, column_3"
new_group_by = "ALL"
with open(Path(path) / "query.sql", "w") as f:
f.write(
f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {new_group_by}"
)
with open(Path(path_previous) / "query.sql", "w") as f:
f.write(
f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {previous_group_by}"
)
with pytest.raises(ClickException) as e:
generate_query_with_shredder_mitigation(
client=mock_client,
project_id=project_id,
dataset=dataset,
destination_table=destination_table,
backfill_date=PREVIOUS_DATE,
)
assert (str(e.value.message)) == expected_exc
# GROUP BY is missing
previous_group_by = "column_1, column_2, column_3"
with open(Path(path) / "query.sql", "w") as f:
f.write("SELECT column_1, column_2 FROM upstream_1")
with open(Path(path_previous) / "query.sql", "w") as f:
f.write(
f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {previous_group_by}"
)
with pytest.raises(ClickException) as e:
generate_query_with_shredder_mitigation(
client=mock_client,
project_id=project_id,
dataset=dataset,
destination_table=destination_table,
backfill_date=PREVIOUS_DATE,
)
assert (str(e.value.message)) == expected_exc
@patch("google.cloud.bigquery.Client")
@patch("bigquery_etl.backfill.shredder_mitigation.classify_columns")
def test_generate_query_called_with_correct_parameters(
self, mock_classify_columns, mock_client, runner
):
with runner.isolated_filesystem():
os.makedirs(self.path, exist_ok=True)
os.makedirs(self.path_previous, exist_ok=True)
with open(Path(self.path) / "query.sql", "w") as f:
f.write("SELECT column_1 FROM upstream_1 GROUP BY column_1")
with open(Path(self.path) / "metadata.yaml", "w") as f:
f.write(
"bigquery:\n time_partitioning:\n type: day\n field: submission_date\n require_partition_filter: true"
)
with open(Path(self.path_previous) / "query.sql", "w") as f:
f.write("SELECT column_1 FROM upstream_1 GROUP BY column_1")
with open(Path(self.path_previous) / "metadata.yaml", "w") as f:
f.write(
"bigquery:\n time_partitioning:\n type: day\n "
"field: submission_date\n require_partition_filter: true"
)
mock_classify_columns.return_value = (
[
Column(
"column_1",
DataTypeGroup.STRING,
ColumnType.DIMENSION,
ColumnStatus.COMMON,
)
],
[
Column(
"column_2",
DataTypeGroup.STRING,
ColumnType.DIMENSION,
ColumnStatus.ADDED,
)
],
[],
[
Column(
"metric_1",
DataTypeGroup.INTEGER,
ColumnType.METRIC,
ColumnStatus.COMMON,
)
],
[],
)
with patch.object(
Subset,
"get_query_path_results",
return_value=[{"column_1": "ABC", "column_2": "DEF", "metric_1": 10.0}],
):
with patch.object(Subset, "generate_query") as mock_generate_query:
generate_query_with_shredder_mitigation(
client=mock_client,
project_id=self.project_id,
dataset=self.dataset,
destination_table=self.destination_table,
backfill_date=PREVIOUS_DATE,
)
assert mock_generate_query.call_count == 3
mock_generate_query.assert_has_calls(
[
call(
select_list=[
"submission_date",
"COALESCE(column_1, '??') AS column_1",
"SUM(metric_1) AS metric_1",
],
from_clause="new_version",
group_by_clause="ALL",
),
call(
select_list=[
"submission_date",
"COALESCE(column_1, '??') AS column_1",
"SUM(metric_1) AS metric_1",
],
from_clause="`moz-fx-data-shared-prod.test.test_query_v1`",
where_clause="submission_date = @submission_date",
group_by_clause="ALL",
),
call(
select_list=[
"previous_agg.submission_date",
"previous_agg.column_1",
"CAST(NULL AS STRING) AS column_2",
"previous_agg.metric_1 - IFNULL(new_agg.metric_1, 0) AS metric_1",
],
from_clause="previous_agg LEFT JOIN new_agg ON previous_agg.submission_date = "
"new_agg.submission_date AND previous_agg.column_1 = new_agg.column_1 ",
where_clause="previous_agg.metric_1 > IFNULL(new_agg.metric_1, 0)",
),
]
)

View file

@ -1,12 +1,9 @@
import os
from pathlib import Path
import pytest
from click.exceptions import BadParameter
from click.testing import CliRunner
from bigquery_etl.cli.utils import (
extract_last_group_by_from_query,
is_authenticated,
is_valid_dir,
is_valid_file,
@ -18,9 +15,6 @@ TEST_DIR = Path(__file__).parent.parent
class TestUtils:
@pytest.fixture
def runner(self):
return CliRunner()
def test_is_valid_dir(self):
with pytest.raises(BadParameter):
@ -71,89 +65,3 @@ class TestUtils:
pattern="telemetry_live.event_v4",
invert=True,
)
def test_extract_last_group_by_from_query_sql(self):
"""Test cases using a sql text."""
assert ["ALL"] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table GROUP BY ALL"
)
assert ["1"] == extract_last_group_by_from_query(
sql_text="SELECT column_1, SUM(metric_1) AS metric_1 FROM test_table GROUP BY 1;"
)
assert ["1", "2", "3"] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table GROUP BY 1, 2, 3"
)
assert ["1", "2", "3"] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table GROUP BY 1, 2, 3"
)
assert ["column_1", "column_2"] == extract_last_group_by_from_query(
sql_text="""SELECT column_1, column_2 FROM test_table GROUP BY column_1, column_2 ORDER BY 1 LIMIT 100"""
)
assert [] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table"
)
assert [] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table;"
)
assert ["column_1"] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table GROUP BY column_1"
)
assert ["column_1", "column_2"] == extract_last_group_by_from_query(
sql_text="SELECT column_1, column_2 FROM test_table GROUP BY (column_1, column_2)"
)
assert ["column_1"] == extract_last_group_by_from_query(
sql_text="""WITH cte AS (SELECT column_1 FROM test_table GROUP BY column_1)
SELECT column_1 FROM cte"""
)
assert ["column_1"] == extract_last_group_by_from_query(
sql_text="""WITH cte AS (SELECT column_1 FROM test_table GROUP BY column_1),
cte2 AS (SELECT column_1, column2 FROM test_table GROUP BY column_1, column2)
SELECT column_1 FROM cte2 GROUP BY column_1 ORDER BY 1 DESC LIMIT 1;"""
)
assert ["column_3"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT column_1, column3 FROM test_table GROUP BY column_1, column3),
cte3 AS (SELECT column_1, column3 FROM cte1 group by column_3) SELECT column_1 FROM cte3 limit 2;"""
)
assert ["column_2"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT column_1 FROM test_table GROUP BY column_1),
'cte2 AS (SELECT column_2 FROM test_table GROUP BY column_2),
cte3 AS (SELECT column_1 FROM cte1 UNION ALL SELECT column2 FROM cte2) SELECT * FROM cte3"""
)
assert ["column_2"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT column_1 FROM test_table GROUP BY column_1),
cte2 AS (SELECT column_1 FROM test_table GROUP BY column_2) SELECT * FROM cte2;"""
)
assert ["COLUMN"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN),
cte2 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN) SELECT * FROM cte2;"""
)
assert ["COLUMN"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN),
cte2 AS (SELECT COLUMN FROM test_table group by COLUMN) SELECT * FROM cte2;"""
)
def test_extract_last_group_by_from_query_file(self, runner):
"""Test function and cases using a sql file path."""
with runner.isolated_filesystem():
test_path = (
"sql/moz-fx-data-shared-prod/test_shredder_mitigation/test_query_v1"
)
os.makedirs(test_path)
assert os.path.exists(test_path)
assert "test_shredder_mitigation" in os.listdir(
"sql/moz-fx-data-shared-prod"
)
assert is_valid_dir(None, None, test_path)
sql_path = Path(test_path) / "query.sql"
with open(sql_path, "w") as f:
f.write("SELECT column_1 FROM test_table group by ALL")
assert ["ALL"] == extract_last_group_by_from_query(sql_path=sql_path)
with open(sql_path, "w") as f:
f.write(
"SELECT column_1 FROM test_table GROUP BY (column_1) LIMIT (column_1);"
)
assert ["column_1"] == extract_last_group_by_from_query(sql_path=sql_path)

View file

@ -1,6 +1,12 @@
import pytest
import os
from pathlib import Path
import pytest
from click.testing import CliRunner
from bigquery_etl.cli.utils import is_valid_dir
from bigquery_etl.util.common import (
extract_last_group_by_from_query,
project_dirs,
qualify_table_references_in_file,
render,
@ -8,6 +14,10 @@ from bigquery_etl.util.common import (
class TestUtilCommon:
@pytest.fixture
def runner(self):
return CliRunner()
def test_project_dirs(self):
assert project_dirs("test") == ["sql/test"]
@ -399,3 +409,89 @@ class TestUtilCommon:
actual = qualify_table_references_in_file(query_path)
assert actual == expected
def test_extract_last_group_by_from_query_file(self, runner):
"""Test cases using a sql file path."""
with runner.isolated_filesystem():
test_path = (
"sql/moz-fx-data-shared-prod/test_shredder_mitigation/test_query_v1"
)
os.makedirs(test_path)
assert os.path.exists(test_path)
assert "test_shredder_mitigation" in os.listdir(
"sql/moz-fx-data-shared-prod"
)
assert is_valid_dir(None, None, test_path)
sql_path = Path(test_path) / "query.sql"
with open(sql_path, "w") as f:
f.write("SELECT column_1 FROM test_table group by ALL")
assert ["ALL"] == extract_last_group_by_from_query(sql_path=sql_path)
with open(sql_path, "w") as f:
f.write(
"SELECT column_1 FROM test_table GROUP BY (column_1) LIMIT (column_1);"
)
assert ["column_1"] == extract_last_group_by_from_query(sql_path=sql_path)
def test_extract_last_group_by_from_query_sql(self):
"""Test cases using a sql text."""
assert ["ALL"] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table GROUP BY ALL"
)
assert ["1"] == extract_last_group_by_from_query(
sql_text="SELECT column_1, SUM(metric_1) AS metric_1 FROM test_table GROUP BY 1;"
)
assert ["1", "2", "3"] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table GROUP BY 1, 2, 3"
)
assert ["1", "2", "3"] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table GROUP BY 1, 2, 3"
)
assert ["column_1", "column_2"] == extract_last_group_by_from_query(
sql_text="""SELECT column_1, column_2 FROM test_table GROUP BY column_1, column_2 ORDER BY 1 LIMIT 100"""
)
assert [] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table"
)
assert [] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table;"
)
assert ["column_1"] == extract_last_group_by_from_query(
sql_text="SELECT column_1 FROM test_table GROUP BY column_1"
)
assert ["column_1", "column_2"] == extract_last_group_by_from_query(
sql_text="SELECT column_1, column_2 FROM test_table GROUP BY (column_1, column_2)"
)
assert ["column_1"] == extract_last_group_by_from_query(
sql_text="""WITH cte AS (SELECT column_1 FROM test_table GROUP BY column_1)
SELECT column_1 FROM cte"""
)
assert ["column_1"] == extract_last_group_by_from_query(
sql_text="""WITH cte AS (SELECT column_1 FROM test_table GROUP BY column_1),
cte2 AS (SELECT column_1, column2 FROM test_table GROUP BY column_1, column2)
SELECT column_1 FROM cte2 GROUP BY column_1 ORDER BY 1 DESC LIMIT 1;"""
)
assert ["column_3"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT column_1, column3 FROM test_table GROUP BY column_1, column3),
cte3 AS (SELECT column_1, column3 FROM cte1 group by column_3) SELECT column_1 FROM cte3 limit 2;"""
)
assert ["column_2"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT column_1 FROM test_table GROUP BY column_1),
'cte2 AS (SELECT column_2 FROM test_table GROUP BY column_2),
cte3 AS (SELECT column_1 FROM cte1 UNION ALL SELECT column2 FROM cte2) SELECT * FROM cte3"""
)
assert ["column_2"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT column_1 FROM test_table GROUP BY column_1),
cte2 AS (SELECT column_1 FROM test_table GROUP BY column_2) SELECT * FROM cte2;"""
)
assert ["COLUMN"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN),
cte2 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN) SELECT * FROM cte2;"""
)
assert ["COLUMN"] == extract_last_group_by_from_query(
sql_text="""WITH cte1 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN),
cte2 AS (SELECT COLUMN FROM test_table group by COLUMN) SELECT * FROM cte2;"""
)