diff --git a/bigquery_etl/backfill/query_with_shredder_mitigation_template.sql b/bigquery_etl/backfill/query_with_shredder_mitigation_template.sql new file mode 100644 index 0000000000..cdf305516c --- /dev/null +++ b/bigquery_etl/backfill/query_with_shredder_mitigation_template.sql @@ -0,0 +1,23 @@ +-- Query generated using a template for shredder mitigation. +WITH {{ new_version_cte }} AS ( + {{ new_version }} +), +{{ new_agg_cte }} AS ( + {{ new_agg }} +), +{{ previous_agg_cte }} AS ( + {{ previous_agg }} +), +{{ shredded_cte }} AS ( + {{ shredded }} +) +SELECT + {{ final_select }} +FROM + {{ new_version_cte }} +UNION ALL +SELECT + {{ final_select}} +FROM + {{ shredded_cte }} +; diff --git a/bigquery_etl/backfill/shredder_mitigation.py b/bigquery_etl/backfill/shredder_mitigation.py index 31d061c2df..53ff18aeb3 100644 --- a/bigquery_etl/backfill/shredder_mitigation.py +++ b/bigquery_etl/backfill/shredder_mitigation.py @@ -1,14 +1,32 @@ -"""Generate a query to backfill an aggregate with shredder mitigation.""" +"""Generate a query with shredder mitigation.""" -from datetime import date, datetime, time, timedelta +import os +import re +from datetime import date +from datetime import datetime as dt +from datetime import time, timedelta from enum import Enum +from pathlib import Path from types import NoneType +from typing import Any, Optional, Tuple +import attr import click +from dateutil import parser +from gcloud.exceptions import NotFound # type: ignore +from google.cloud import bigquery +from jinja2 import Environment, FileSystemLoader +from bigquery_etl.format_sql.formatter import reformat +from bigquery_etl.metadata.parse_metadata import METADATA_FILE, Metadata +from bigquery_etl.util.common import extract_last_group_by_from_query, write_sql + +PREVIOUS_DATE = (dt.now() - timedelta(days=2)).date() +SUFFIX = dt.now().strftime("%Y%m%d%H%M%S") TEMP_DATASET = "tmp" -SUFFIX = datetime.now().strftime("%Y%m%d%H%M%S") -PREVIOUS_DATE = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d") +THIS_PATH = Path(os.path.dirname(__file__)) +DEFAULT_PROJECT_ID = "moz-fx-data-shared-prod" +QUERY_WITH_MITIGATION_NAME = "query_with_shredder_mitigation" class ColumnType(Enum): @@ -29,97 +47,259 @@ class ColumnStatus(Enum): class DataTypeGroup(Enum): - """Data types in BigQuery. Not including ARRAY and STRUCT as these are not expected in aggregate tables.""" + """Data types in BigQuery. 
Not supported/expected in aggregates: TIMESTAMP, ARRAY, STRUCT.""" STRING = ("STRING", "BYTES") BOOLEAN = "BOOLEAN" + INTEGER = ("INTEGER", "INT64", "INT", "SMALLINT", "TINYINT", "BYTEINT") NUMERIC = ( - "INTEGER", "NUMERIC", "BIGNUMERIC", - "DECIMAL", - "INT64", - "INT", - "SMALLINT", - "BIGINT", - "TINYINT", - "BYTEINT", ) - FLOAT = ("FLOAT",) - DATE = ("DATE", "DATETIME", "TIME", "TIMESTAMP") + FLOAT = "FLOAT" + DATE = "DATE" + DATETIME = "DATETIME" + TIME = "TIME" + TIMESTAMP = "TIMESTAMP" UNDETERMINED = "None" +@attr.define(eq=True) class Column: """Representation of a column in a query, with relevant details for shredder mitigation.""" - def __init__(self, name, data_type, column_type, status): - """Initialize class with required attributes.""" - self.name = name - self.data_type = data_type - self.column_type = column_type - self.status = status + name: str + data_type: DataTypeGroup = attr.field(default=DataTypeGroup.UNDETERMINED) + column_type: ColumnType = attr.field(default=ColumnType.UNDETERMINED) + status: ColumnStatus = attr.field(default=ColumnStatus.UNDETERMINED) - def __eq__(self, other): - """Return attributes only if the referenced object is of type Column.""" - if not isinstance(other, Column): - return NotImplemented - return ( - self.name == other.name - and self.data_type == other.data_type - and self.column_type == other.column_type - and self.status == other.status + """Validate the type of the attributes.""" + + @data_type.validator + def validate_data_type(self, attribute, value): + """Check that the type of data_type is as expected.""" + if not isinstance(value, DataTypeGroup): + raise ValueError(f"Invalid {value} with type: {type(value)}.") + + @column_type.validator + def validate_column_type(self, attribute, value): + """Check that the type of parameter column_type is as expected.""" + if not isinstance(value, ColumnType): + raise ValueError(f"Invalid data type for: {value}.") + + @status.validator + def validate_status(self, attribute, value): + """Check that the type of parameter column_status is as expected.""" + if not isinstance(value, ColumnStatus): + raise ValueError(f"Invalid data type for: {value}.") + + +@attr.define(eq=True) +class Subset: + """Representation of a subset/CTEs in the query and the actions related to this subset.""" + + client: bigquery.Client + destination_table: str = attr.field(default="") + query_cte: str = attr.field(default="") + dataset: str = attr.field(default=TEMP_DATASET) + project_id: str = attr.field(default=DEFAULT_PROJECT_ID) + expiration_days: Optional[float] = attr.field(default=None) + + @property + def expiration_ms(self) -> Optional[float]: + """Convert partition expiration from days to milliseconds.""" + if self.expiration_days is None: + return None + return int(self.expiration_days * 86_400_000) + + @property + def version(self): + """Return the version of the destination table.""" + match = re.search(r"v(\d+)$", self.destination_table) + try: + version = int(match.group(1)) + if not isinstance(version, int): + raise click.ClickException( + f"{self.destination_table} must end with a positive integer." + ) + return version + except (AttributeError, TypeError): + raise click.ClickException( + f"Invalid or missing table version in {self.destination_table}." 
+ ) + + @property + def full_table_id(self): + """Return the full id of the destination table.""" + return f"{self.project_id}.{self.dataset}.{self.destination_table}" + + @property + def query_path(self): + """Return the full path of the query.sql file associated with the subset.""" + sql_path = ( + Path("sql") + / self.project_id + / self.dataset + / self.destination_table + / "query.sql" ) + if not os.path.isfile(sql_path): + click.echo( + click.style(f"Required file not found: {sql_path}.", fg="yellow") + ) + return None + return sql_path - def __repr__(self): - """Return a string representation of the object.""" - return f"Column(name={self.name}, data_type={self.data_type}, column_type={self.column_type}, status={self.status})" + @property + def partitioning(self): + """Return the partition details of the destination table.""" + metadata = Metadata.from_file( + Path("sql") + / self.project_id + / self.dataset + / self.destination_table + / METADATA_FILE + ) + if metadata.bigquery and metadata.bigquery.time_partitioning: + partitioning = { + "type": metadata.bigquery.time_partitioning.type.name, + "field": metadata.bigquery.time_partitioning.field, + } + else: + partitioning = {"type": None, "field": None} + return partitioning + + def generate_query( + self, + select_list, + from_clause, + where_clause=None, + group_by_clause=None, + order_by_clause=None, + having_clause=None, + ): + """Build query to populate the table.""" + if not select_list or not from_clause: + raise click.ClickException( + f"Missing required clause to generate query.\n" + f"Actuals: SELECT: {select_list}, FROM: {self.full_table_id}" + ) + query = f"SELECT {', '.join(map(str, select_list))}" + query += f" FROM {from_clause}" if from_clause is not None else "" + query += f" WHERE {where_clause}" if where_clause is not None else "" + query += f" GROUP BY {group_by_clause}" if group_by_clause is not None else "" + query += ( + f" HAVING {having_clause}" + if having_clause is not None and group_by_clause is not None + else "" + ) + query += f" ORDER BY {order_by_clause}" if order_by_clause is not None else "" + return query + + def get_query_path_results( + self, + backfill_date: date = PREVIOUS_DATE, + row_limit: Optional[int] = None, + **kwargs, + ) -> list[dict[str, Any]]: + """Run the query in sql_path & return result or number of rows requested.""" + having_clause = None + for key, value in kwargs.items(): + if key.lower() == "having_clause": + having_clause = f"{value}" + + with open(self.query_path, "r") as file: + sql_text = file.read().strip() + + if sql_text.endswith(";"): + sql_text = sql_text[:-1] + if having_clause: + sql_text = f"{sql_text} {having_clause}" + if row_limit: + sql_text = f"{sql_text} LIMIT {row_limit}" + + partition_field = self.partitioning["field"] + partition_type = ( + "DATE" if self.partitioning["type"] == "DAY" else self.partitioning["type"] + ) + parameters = None + if partition_field is not None: + parameters = [ + bigquery.ScalarQueryParameter( + partition_field, partition_type, backfill_date + ), + ] + + try: + query_job = self.client.query( + query=sql_text, + job_config=bigquery.QueryJobConfig( + query_parameters=parameters, + use_legacy_sql=False, + dry_run=False, + use_query_cache=False, + ), + ) + query_results = query_job.result() + except NotFound as e: + raise click.ClickException( + f"Unable to query data for {backfill_date}. Table {self.full_table_id} not found." 
+ ) from e + rows = [dict(row) for row in query_results] + return rows + + def compare_current_and_previous_version( + self, + date_partition_parameter, + ): + """Generate and run a data check to compare existing and backfilled data.""" + return NotImplemented def get_bigquery_type(value) -> DataTypeGroup: - """Return the datatype of a value, grouping similar types.""" - date_formats = [ - ("%Y-%m-%d", date), - ("%Y-%m-%d %H:%M:%S", datetime), - ("%Y-%m-%dT%H:%M:%S", datetime), - ("%Y-%m-%dT%H:%M:%SZ", datetime), - ("%Y-%m-%d %H:%M:%S UTC", datetime), - ("%H:%M:%S", time), - ] - for format, dtype in date_formats: - try: - if dtype == time: - parsed_to_time = datetime.strptime(value, format).time() - if isinstance(parsed_to_time, time): - return DataTypeGroup.DATE - else: - parsed_to_date = datetime.strptime(value, format) - if isinstance(parsed_to_date, dtype): - return DataTypeGroup.DATE - except (ValueError, TypeError): - continue + """Find the datatype of a value, grouping similar types.""" + if isinstance(value, dt): + return DataTypeGroup.DATETIME + try: + if isinstance(dt.strptime(value, "%H:%M:%S").time(), time): + return DataTypeGroup.TIME + except (ValueError, TypeError, AttributeError): + pass + try: + value_parsed = parser.isoparse(value.replace(" UTC", "Z").replace(" ", "T")) + if ( + isinstance(value_parsed, dt) + and value_parsed.time() == time(0, 0) + and isinstance(dt.strptime(value, "%Y-%m-%d"), date) + ): + return DataTypeGroup.DATE + if isinstance(value_parsed, dt) and value_parsed.tzinfo is None: + return DataTypeGroup.DATETIME + if isinstance(value_parsed, dt) and value_parsed.tzinfo is not None: + return DataTypeGroup.TIMESTAMP + except (ValueError, TypeError, AttributeError): + pass if isinstance(value, time): - return DataTypeGroup.DATE + return DataTypeGroup.TIME if isinstance(value, date): return DataTypeGroup.DATE - elif isinstance(value, bool): + if isinstance(value, bool): return DataTypeGroup.BOOLEAN if isinstance(value, int): - return DataTypeGroup.NUMERIC - elif isinstance(value, float): + return DataTypeGroup.INTEGER + if isinstance(value, float): return DataTypeGroup.FLOAT - elif isinstance(value, str) or isinstance(value, bytes): + if isinstance(value, (str, bytes)): return DataTypeGroup.STRING - elif isinstance(value, NoneType): + if isinstance(value, NoneType): return DataTypeGroup.UNDETERMINED - else: - raise ValueError(f"Unsupported data type: {type(value)}") + raise ValueError(f"Unsupported data type: {type(value)}") def classify_columns( new_row: dict, existing_dimension_columns: list, new_dimension_columns: list ) -> tuple[list[Column], list[Column], list[Column], list[Column], list[Column]]: - """Compare the new row with the existing columns and return the list of common, added and removed columns by type.""" + """Compare new row with existing columns & return common, added & removed columns.""" common_dimensions = [] added_dimensions = [] removed_dimensions = [] @@ -128,8 +308,10 @@ def classify_columns( if not new_row or not existing_dimension_columns or not new_dimension_columns: raise click.ClickException( - f"Missing required parameters. Received: new_row= {new_row}\n" - f"existing_dimension_columns= {existing_dimension_columns},\nnew_dimension_columns= {new_dimension_columns}." + f"\n\nMissing one or more required parameters. Received:" + f"\nnew_row= {new_row}" + f"\nexisting_dimension_columns= {existing_dimension_columns}," + f"\nnew_dimension_columns= {new_dimension_columns}." 
        )

     missing_dimensions = [
@@ -137,7 +319,8 @@ def classify_columns(
     ]
     if not len(missing_dimensions) == 0:
         raise click.ClickException(
-            f"Inconsistent parameters. Columns in new dimensions not found in new row: {missing_dimensions}"
+            f"Existing dimensions don't match columns retrieved by query."
+            f" Missing {missing_dimensions}."
         )

     for key in existing_dimension_columns:
@@ -175,7 +358,9 @@ classify_columns(
         key not in existing_dimension_columns
         and key not in new_dimension_columns
         and (
-            value_type is DataTypeGroup.NUMERIC or value_type is DataTypeGroup.FLOAT
+            value_type is DataTypeGroup.INTEGER
+            or value_type is DataTypeGroup.FLOAT
+            or value_type is DataTypeGroup.NUMERIC
         )
     ):
         # Columns that are not in the previous or new list of grouping columns are metrics.
@@ -212,3 +397,311 @@ def classify_columns(
         metrics_sorted,
         undefined_sorted,
     )
+
+
+def generate_query_with_shredder_mitigation(
+    client, project_id, dataset, destination_table, backfill_date=PREVIOUS_DATE
+) -> Tuple[Path, str]:
+    """Generate a query to backfill with shredder mitigation."""
+    query_with_mitigation_path = Path("sql") / project_id
+
+    # Find query files and grouping of previous and new queries.
+    new = Subset(client, destination_table, "new_version", dataset, project_id, None)
+    if new.version < 2:
+        raise click.ClickException(
+            f"The new version of the table is expected >= 2. Actual is {new.version}."
+        )
+
+    destination_table_previous_version = (
+        f"{destination_table[:-len(str(new.version))]}{new.version-1}"
+    )
+    previous = Subset(
+        client,
+        destination_table_previous_version,
+        "previous",
+        dataset,
+        project_id,
+        None,
+    )
+    new_group_by = extract_last_group_by_from_query(sql_path=new.query_path)
+    previous_group_by = extract_last_group_by_from_query(sql_path=previous.query_path)
+
+    # Check that previous query exists and GROUP BYs are valid in both queries.
+    integers_in_group_by = False
+    for e in previous_group_by + new_group_by:
+        try:
+            int(e)
+            integers_in_group_by = True
+        except ValueError:
+            continue
+    if (
+        "ALL" in previous_group_by
+        or "ALL" in new_group_by
+        or not all(isinstance(e, str) for e in previous_group_by)
+        or not all(isinstance(e, str) for e in new_group_by)
+        or not previous_group_by
+        or not new_group_by
+        or integers_in_group_by
+    ):
+        raise click.ClickException(
+            "GROUP BY must use an explicit list of columns. "
+            "Avoid expressions like `GROUP BY ALL` or `GROUP BY 1, 2, 3`."
+        )
+
+    # Identify the columns common to both queries and the newly added columns. This excludes removed columns.
+    sample_rows = new.get_query_path_results(
+        backfill_date=backfill_date,
+        row_limit=1,
+        having_clause=f"HAVING {' IS NOT NULL AND '.join(new_group_by)} IS NOT NULL",
+    )
+    if not sample_rows:
+        sample_rows = new.get_query_path_results(
+            backfill_date=backfill_date, row_limit=1
+        )
+    try:
+        new_table_row = sample_rows[0]
+        (
+            common_dimensions,
+            added_dimensions,
+            removed_dimensions,
+            metrics,
+            undetermined_columns,
+        ) = classify_columns(new_table_row, previous_group_by, new_group_by)
+    except TypeError as e:
+        raise click.ClickException(
+            f"Table {destination_table} did not return any rows for {backfill_date}.\n{e}"
+        )
+
+    if not common_dimensions or not added_dimensions or not metrics:
+        raise click.ClickException(
+            "The process requires that the previous & new queries have at least one dimension in common,"
+            " one dimension added and one metric."
+        )
+
+    # Get the new query.
+ with open(new.query_path, "r") as file: + new_query = file.read().strip() + + # Aggregate previous data and new query results using common dimensions. + new_agg = Subset( + client, destination_table, "new_agg", TEMP_DATASET, project_id, None + ) + previous_agg = Subset( + client, + destination_table_previous_version, + "previous_agg", + TEMP_DATASET, + project_id, + None, + ) + common_select = ( + [previous.partitioning["field"]] + + [ + f"COALESCE({dim.name}, '??') AS {dim.name}" + for dim in common_dimensions + if ( + dim.name != previous.partitioning["field"] + and dim.data_type == DataTypeGroup.STRING + ) + ] + + [ + f"COALESCE({dim.name}, -999) AS {dim.name}" + for dim in common_dimensions + if ( + dim.name != new.partitioning["field"] + and dim.data_type in (DataTypeGroup.INTEGER, DataTypeGroup.FLOAT) + ) + ] + + [ + dim.name + for dim in common_dimensions + if ( + dim.name != new.partitioning["field"] + and dim.data_type in (DataTypeGroup.BOOLEAN, DataTypeGroup.DATE) + ) + ] + + [f"SUM({metric.name}) AS {metric.name}" for metric in metrics] + ) + new_agg_query = new_agg.generate_query( + select_list=common_select, + from_clause=f"{new.query_cte}", + group_by_clause="ALL", + ) + previous_agg_query = previous_agg.generate_query( + select_list=common_select, + from_clause=f"`{previous.full_table_id}`", + where_clause=f"{previous.partitioning['field']} = @{previous.partitioning['field']}", + group_by_clause="ALL", + ) + + # Calculate shredder impact. + shredded = Subset( + client, destination_table, "shredded", TEMP_DATASET, project_id, None + ) + + # Set values to NULL for the supported types. + shredded_select = ( + [f"{previous_agg.query_cte}.{new.partitioning['field']}"] + + [ + f"{previous_agg.query_cte}.{dim.name}" + for dim in common_dimensions + if (dim.name != new.partitioning["field"]) + ] + + [ + f"CAST(NULL AS {DataTypeGroup.STRING.name}) AS {dim.name}" + for dim in added_dimensions + if ( + dim.name != new.partitioning["field"] + and dim.data_type == DataTypeGroup.STRING + ) + ] + + [ + f"CAST(NULL AS {DataTypeGroup.BOOLEAN.name}) AS {dim.name}" + for dim in added_dimensions + if ( + dim.name != new.partitioning["field"] + and dim.data_type == DataTypeGroup.BOOLEAN + ) + ] + # This doesn't convert data or dtypes, it's only used to cast NULLs for UNION queries. 
+        + [
+            f"CAST(NULL AS {DataTypeGroup.DATE.name}) AS {dim.name}"
+            for dim in added_dimensions
+            if (
+                dim.name != new.partitioning["field"]
+                and dim.data_type == DataTypeGroup.DATE
+            )
+        ]
+        + [
+            f"CAST(NULL AS {DataTypeGroup.INTEGER.name}) AS {dim.name}"
+            for dim in added_dimensions
+            if (
+                dim.name != new.partitioning["field"]
+                and dim.data_type == DataTypeGroup.INTEGER
+            )
+        ]
+        + [
+            f"CAST(NULL AS {DataTypeGroup.FLOAT.name}) AS {dim.name}"
+            for dim in added_dimensions
+            if (
+                dim.name != new.partitioning["field"]
+                and dim.data_type == DataTypeGroup.FLOAT
+            )
+        ]
+        + [
+            f"NULL AS {dim.name}"
+            for dim in added_dimensions
+            if (
+                dim.name != new.partitioning["field"]
+                and dim.data_type
+                not in (
+                    DataTypeGroup.STRING,
+                    DataTypeGroup.BOOLEAN,
+                    DataTypeGroup.INTEGER,
+                    DataTypeGroup.FLOAT,
+                )
+            )
+        ]
+        + [
+            f"{previous_agg.query_cte}.{metric.name} - IFNULL({new_agg.query_cte}.{metric.name}, 0)"
+            f" AS {metric.name}"
+            for metric in metrics
+            if metric.data_type != DataTypeGroup.FLOAT
+        ]
+        + [
+            f"ROUND({previous_agg.query_cte}.{metric.name}, 3) - "
+            f"ROUND(IFNULL({new_agg.query_cte}.{metric.name}, 0), 3) AS {metric.name}"
+            for metric in metrics
+            if metric.data_type == DataTypeGroup.FLOAT
+        ]
+    )
+
+    shredded_join = " AND ".join(
+        [
+            f"{previous_agg.query_cte}.{previous.partitioning['field']} ="
+            f" {new_agg.query_cte}.{new.partitioning['field']}"
+        ]
+        + [
+            f"{previous_agg.query_cte}.{dim.name} = {new_agg.query_cte}.{dim.name}"
+            for dim in common_dimensions
+            if dim.name != previous.partitioning["field"]
+            and dim.data_type not in (DataTypeGroup.BOOLEAN, DataTypeGroup.DATE)
+        ]
+        + [
+            f"({previous_agg.query_cte}.{dim.name} = {new_agg.query_cte}.{dim.name} OR"
+            f" ({previous_agg.query_cte}.{dim.name} IS NULL"
+            f" AND {new_agg.query_cte}.{dim.name} IS NULL))"  # Compare null values.
+            for dim in common_dimensions
+            if dim.name != previous.partitioning["field"]
+            and dim.data_type in (DataTypeGroup.BOOLEAN, DataTypeGroup.DATE)
+        ]
+    )
+    shredded_query = shredded.generate_query(
+        select_list=shredded_select,
+        from_clause=f"{previous_agg.query_cte} LEFT JOIN {new_agg.query_cte} ON {shredded_join} ",
+        where_clause=" OR ".join(
+            [
+                f"{previous_agg.query_cte}.{metric.name} > IFNULL({new_agg.query_cte}.{metric.name}, 0)"
+                for metric in metrics
+            ]
+        ),
+    )
+
+    combined_list = (
+        [dim.name for dim in common_dimensions]
+        + [dim.name for dim in added_dimensions]
+        + [metric.name for metric in metrics]
+    )
+    final_select = f"{', '.join(combined_list)}"
+
+    click.echo(
+        click.style(
+            f"""Generating query with shredder mitigation and the following columns:
+            Dimensions in both versions:
+            {[f"{dim.name}:{dim.data_type.name}" for dim in common_dimensions]},
+            Dimensions added:
+            {[f"{dim.name}:{dim.data_type.name}" for dim in added_dimensions]}
+            Metrics:
+            {[f"{dim.name}:{dim.data_type.name}" for dim in metrics]},
+            Columns that could not be classified:
+            {[f"{dim.name}:{dim.data_type.name}" for dim in undetermined_columns]}.""",
+            fg="yellow",
+        )
+    )
+
+    # Generate query from template.
+ env = Environment(loader=FileSystemLoader(str(THIS_PATH))) + query_with_mitigation_template = env.get_template( + f"{QUERY_WITH_MITIGATION_NAME}_template.sql" + ) + + query_with_mitigation_sql = reformat( + query_with_mitigation_template.render( + new_version_cte=new.query_cte, + new_version=new_query, + new_agg_cte=new_agg.query_cte, + new_agg=new_agg_query, + previous_agg_cte=previous_agg.query_cte, + previous_agg=previous_agg_query, + shredded_cte=shredded.query_cte, + shredded=shredded_query, + final_select=final_select, + ) + ) + write_sql( + output_dir=query_with_mitigation_path, + full_table_id=new.full_table_id, + basename=f"{QUERY_WITH_MITIGATION_NAME}.sql", + sql=query_with_mitigation_sql, + skip_existing=False, + ) + + # return Path("sql") + return ( + Path("sql") + / new.project_id + / new.dataset + / new.destination_table + / f"{QUERY_WITH_MITIGATION_NAME}.sql", + query_with_mitigation_sql, + ) diff --git a/bigquery_etl/cli/backfill.py b/bigquery_etl/cli/backfill.py index 01c11dce40..c4776efc80 100644 --- a/bigquery_etl/cli/backfill.py +++ b/bigquery_etl/cli/backfill.py @@ -22,6 +22,7 @@ from ..backfill.parse import ( Backfill, BackfillStatus, ) +from ..backfill.shredder_mitigation import generate_query_with_shredder_mitigation from ..backfill.utils import ( get_backfill_backup_table_name, get_backfill_file_from_qualified_table_name, @@ -508,10 +509,27 @@ def _initiate_backfill( custom_query = None if entry.shredder_mitigation is True: - custom_query = "query_with_shredder_mitigation.sql" # TODO: Replace with "= generate_query_with_shredder_mitigation()" - logging_str += " This backfill uses a query with shredder mitigation." + click.echo( + click.style( + f"Generating query with shredder mitigation for {dataset}.{table}...", + fg="blue", + ) + ) + query, _ = generate_query_with_shredder_mitigation( + client=bigquery.Client(project=project), + project_id=project, + dataset=dataset, + destination_table=table, + backfill_date=entry.start_date.isoformat(), + ) + custom_query = Path(query) + click.echo( + click.style( + f"Starting backfill with custom query: '{custom_query}'.", fg="blue" + ) + ) elif entry.custom_query: - custom_query = entry.custom_query + custom_query = Path(entry.custom_query) # backfill table # in the long-run we should remove the query backfill command and require a backfill entry for all backfills diff --git a/bigquery_etl/cli/query.py b/bigquery_etl/cli/query.py index 82364ba50e..ed163375ca 100644 --- a/bigquery_etl/cli/query.py +++ b/bigquery_etl/cli/query.py @@ -675,9 +675,7 @@ def backfill( sys.exit(1) if custom_query: - query_files = paths_matching_name_pattern( - name, sql_dir, project_id, [custom_query] - ) + query_files = paths_matching_name_pattern(custom_query, sql_dir, project_id) else: query_files = paths_matching_name_pattern(name, sql_dir, project_id) diff --git a/bigquery_etl/cli/utils.py b/bigquery_etl/cli/utils.py index 4794ad819f..6786bcfeaf 100644 --- a/bigquery_etl/cli/utils.py +++ b/bigquery_etl/cli/utils.py @@ -251,47 +251,3 @@ def temp_dataset_option( help="Dataset where intermediate query results will be temporarily stored, " "formatted as PROJECT_ID.DATASET_ID", ) - - -def extract_last_group_by_from_query( - sql_path: Optional[Path] = None, sql_text: Optional[str] = None -): - """Return the list of columns in the latest group by of a query.""" - if not sql_path and not sql_text: - click.ClickException( - "Please provide an sql file or sql text to extract the group by." 
- ) - - if sql_path: - try: - query_text = sql_path.read_text() - except (FileNotFoundError, OSError): - click.ClickException(f'Failed to read query from: "{sql_path}."') - else: - query_text = str(sql_text) - group_by_list = [] - - # Remove single and multi-line comments (/* */), trailing semicolon if present and normalize whitespace. - query_text = re.sub(r"/\*.*?\*/", "", query_text, flags=re.DOTALL) - query_text = re.sub(r"--[^\n]*\n", "\n", query_text) - query_text = re.sub(r"\s+", " ", query_text).strip() - if query_text.endswith(";"): - query_text = query_text[:-1].strip() - - last_group_by_original = re.search( - r"(?i){}(?!.*{})".format(re.escape("GROUP BY"), re.escape("GROUP BY")), - query_text, - re.DOTALL, - ) - - if last_group_by_original: - group_by = query_text[last_group_by_original.end() :].lstrip() - # Remove parenthesis, closing parenthesis, LIMIT, ORDER BY and text after those. Remove also opening parenthesis. - group_by = ( - re.sub(r"(?i)[\n\)].*|LIMIT.*|ORDER BY.*", "", group_by) - .replace("(", "") - .strip() - ) - if group_by: - group_by_list = group_by.replace(" ", "").replace("\n", "").split(",") - return group_by_list diff --git a/bigquery_etl/util/common.py b/bigquery_etl/util/common.py index 9ba98969da..d0a3323ddb 100644 --- a/bigquery_etl/util/common.py +++ b/bigquery_etl/util/common.py @@ -9,9 +9,10 @@ import string import tempfile import warnings from pathlib import Path -from typing import List, Set, Tuple +from typing import List, Optional, Set, Tuple from uuid import uuid4 +import click import sqlglot from google.cloud import bigquery from jinja2 import Environment, FileSystemLoader @@ -287,6 +288,50 @@ def qualify_table_references_in_file(path: Path) -> str: return updated_query +def extract_last_group_by_from_query( + sql_path: Optional[Path] = None, sql_text: Optional[str] = None +): + """Return the list of columns in the latest group by of a query.""" + if not sql_path and not sql_text: + raise click.ClickException( + "Missing an sql file or sql text to extract the group by." + ) + + if sql_path: + try: + query_text = sql_path.read_text() + except (FileNotFoundError, OSError): + raise click.ClickException(f'Failed to read query from: "{sql_path}."') + else: + query_text = str(sql_text) + group_by_list = [] + + # Remove single and multi-line comments (/* */), trailing semicolon if present and normalize whitespace. + query_text = re.sub(r"/\*.*?\*/", "", query_text, flags=re.DOTALL) + query_text = re.sub(r"--[^\n]*\n", "\n", query_text) + query_text = re.sub(r"\s+", " ", query_text).strip() + if query_text.endswith(";"): + query_text = query_text[:-1].strip() + + last_group_by_original = re.search( + r"(?i){}(?!.*{})".format(re.escape("GROUP BY"), re.escape("GROUP BY")), + query_text, + re.DOTALL, + ) + + if last_group_by_original: + group_by = query_text[last_group_by_original.end() :].lstrip() + # Remove parenthesis, closing parenthesis, LIMIT, ORDER BY and text after those. Remove also opening parenthesis. 
+ group_by = ( + re.sub(r"(?i)[\n\)].*|LIMIT.*|ORDER BY.*", "", group_by) + .replace("(", "") + .strip() + ) + if group_by: + group_by_list = group_by.replace(" ", "").replace("\n", "").split(",") + return group_by_list + + class TempDatasetReference(bigquery.DatasetReference): """Extend DatasetReference to simplify generating temporary tables.""" diff --git a/tests/backfill/test_shredder_mitigation.py b/tests/backfill/test_shredder_mitigation.py index 6ba4dea103..1baa555374 100644 --- a/tests/backfill/test_shredder_mitigation.py +++ b/tests/backfill/test_shredder_mitigation.py @@ -1,14 +1,23 @@ -from datetime import datetime, time +import os +from datetime import date, datetime, time +from pathlib import Path +from unittest.mock import call, patch +import click import pytest from click.exceptions import ClickException +from click.testing import CliRunner from bigquery_etl.backfill.shredder_mitigation import ( + PREVIOUS_DATE, + QUERY_WITH_MITIGATION_NAME, Column, ColumnStatus, ColumnType, DataTypeGroup, + Subset, classify_columns, + generate_query_with_shredder_mitigation, get_bigquery_type, ) @@ -54,7 +63,7 @@ class TestClassifyColumns(object): ), Column( name="first_seen_year", - data_type=DataTypeGroup.NUMERIC, + data_type=DataTypeGroup.INTEGER, column_type=ColumnType.DIMENSION, status=ColumnStatus.COMMON, ), @@ -93,7 +102,7 @@ class TestClassifyColumns(object): ), Column( name="metric_numeric", - data_type=DataTypeGroup.NUMERIC, + data_type=DataTypeGroup.INTEGER, column_type=ColumnType.METRIC, status=ColumnStatus.COMMON, ), @@ -135,7 +144,7 @@ class TestClassifyColumns(object): ), Column( name="first_seen_year", - data_type=DataTypeGroup.NUMERIC, + data_type=DataTypeGroup.INTEGER, column_type=ColumnType.DIMENSION, status=ColumnStatus.COMMON, ), @@ -167,7 +176,7 @@ class TestClassifyColumns(object): ), Column( name="metric_numeric", - data_type=DataTypeGroup.NUMERIC, + data_type=DataTypeGroup.INTEGER, column_type=ColumnType.METRIC, status=ColumnStatus.COMMON, ), @@ -240,7 +249,7 @@ class TestClassifyColumns(object): ), Column( name="metric_numeric", - data_type=DataTypeGroup.NUMERIC, + data_type=DataTypeGroup.INTEGER, column_type=ColumnType.METRIC, status=ColumnStatus.COMMON, ), @@ -384,7 +393,7 @@ class TestClassifyColumns(object): expected_metrics = [ Column( name="metric_bigint", - data_type=DataTypeGroup.NUMERIC, + data_type=DataTypeGroup.INTEGER, column_type=ColumnType.METRIC, status=ColumnStatus.COMMON, ), @@ -396,7 +405,7 @@ class TestClassifyColumns(object): ), Column( name="metric_int", - data_type=DataTypeGroup.NUMERIC, + data_type=DataTypeGroup.INTEGER, column_type=ColumnType.METRIC, status=ColumnStatus.COMMON, ), @@ -411,7 +420,7 @@ class TestClassifyColumns(object): assert metrics == expected_metrics assert undefined == [] - def test_matching_new_row_and_new_columns(self): + def test_not_matching_new_row_and_new_columns(self): new_row = { "submission_date": "2024-01-01", "channel": None, @@ -420,53 +429,54 @@ class TestClassifyColumns(object): } existing_columns = ["submission_date", "channel"] new_columns = ["submission_date", "channel", "os", "is_default_browser"] - expected_exception_text = "Inconsistent parameters. Columns in new dimensions not found in new row: ['is_default_browser']" + expected_exception_text = "Existing dimensions don't match columns retrieved by query. Missing ['is_default_browser']." 
with pytest.raises(ClickException) as e: classify_columns(new_row, existing_columns, new_columns) - assert (str(e.value)) == expected_exception_text + assert (str(e.value.message)) == expected_exception_text def test_missing_parameters(self): new_row = {} expected_exception_text = ( - f"Missing required parameters. Received: new_row= {new_row}\n" + f"\n\nMissing one or more required parameters. Received:\nnew_row= {new_row}\n" f"existing_dimension_columns= [],\nnew_dimension_columns= []." ) with pytest.raises(ClickException) as e: classify_columns(new_row, [], []) - assert (str(e.value)) == expected_exception_text + assert (str(e.value.message)) == expected_exception_text new_row = {"column_1": "2024-01-01", "column_2": "Windows"} + new_columns = {"column_2"} expected_exception_text = ( - f"Missing required parameters. Received: new_row= {new_row}\n" - f"existing_dimension_columns= [],\nnew_dimension_columns= []." + f"\n\nMissing one or more required parameters. Received:\nnew_row= {new_row}\n" + f"existing_dimension_columns= [],\nnew_dimension_columns= {new_columns}." ) with pytest.raises(ClickException) as e: - classify_columns(new_row, [], []) - assert (str(e.value)) == expected_exception_text + classify_columns(new_row, [], new_columns) + assert (str(e.value.message)) == expected_exception_text new_row = {"column_1": "2024-01-01", "column_2": "Windows"} existing_columns = ["column_1"] expected_exception_text = ( - f"Missing required parameters. Received: new_row= {new_row}\n" + f"\n\nMissing one or more required parameters. Received:\nnew_row= {new_row}\n" f"existing_dimension_columns= {existing_columns},\nnew_dimension_columns= []." ) with pytest.raises(ClickException) as e: classify_columns(new_row, existing_columns, []) - assert (str(e.value)) == expected_exception_text + assert (str(e.value.message)) == expected_exception_text class TestGetBigqueryType(object): def test_numeric_group(self): - assert get_bigquery_type(3) == DataTypeGroup.NUMERIC - assert get_bigquery_type(2024) == DataTypeGroup.NUMERIC - assert get_bigquery_type(9223372036854775807) == DataTypeGroup.NUMERIC - assert get_bigquery_type(123456) == DataTypeGroup.NUMERIC - assert get_bigquery_type(-123456) == DataTypeGroup.NUMERIC + assert get_bigquery_type(3) == DataTypeGroup.INTEGER + assert get_bigquery_type(2024) == DataTypeGroup.INTEGER + assert get_bigquery_type(9223372036854775807) == DataTypeGroup.INTEGER + assert get_bigquery_type(123456) == DataTypeGroup.INTEGER + assert get_bigquery_type(-123456) == DataTypeGroup.INTEGER assert get_bigquery_type(789.01) == DataTypeGroup.FLOAT - assert get_bigquery_type(1.00000000000000000000000456) == DataTypeGroup.FLOAT assert get_bigquery_type(999999999999999999999.999999999) == DataTypeGroup.FLOAT assert get_bigquery_type(-1.23456) == DataTypeGroup.FLOAT + assert get_bigquery_type(100000000000000000000.123456789) == DataTypeGroup.FLOAT def test_boolean_group(self): assert get_bigquery_type(False) == DataTypeGroup.BOOLEAN @@ -474,14 +484,630 @@ class TestGetBigqueryType(object): def test_date_type(self): assert get_bigquery_type("2024-01-01") == DataTypeGroup.DATE - assert get_bigquery_type("2024-01-01T10:00:00") == DataTypeGroup.DATE - assert get_bigquery_type("2024-01-01T10:00:00Z") == DataTypeGroup.DATE - assert get_bigquery_type("2024-01-01T10:00:00") == DataTypeGroup.DATE - assert get_bigquery_type("2024-08-01 12:34:56 UTC") == DataTypeGroup.DATE - assert get_bigquery_type("12:34:56") == DataTypeGroup.DATE - assert get_bigquery_type(time(12, 34, 56)) == 
DataTypeGroup.DATE - assert get_bigquery_type(datetime(2024, 12, 26)) == DataTypeGroup.DATE + assert get_bigquery_type("2024-01-01T10:00:00") == DataTypeGroup.DATETIME + assert get_bigquery_type("2024-01-01T10:00:00Z") == DataTypeGroup.TIMESTAMP + assert get_bigquery_type("2024-01-01T10:00:00") == DataTypeGroup.DATETIME + assert get_bigquery_type("2024-08-01 12:34:56 UTC") == DataTypeGroup.TIMESTAMP + assert get_bigquery_type("2024-09-02 14:30:45") == DataTypeGroup.DATETIME + assert get_bigquery_type("12:34:56") == DataTypeGroup.TIME + assert get_bigquery_type(time(12, 34, 56)) == DataTypeGroup.TIME + assert get_bigquery_type(datetime(2024, 12, 26)) == DataTypeGroup.DATETIME + assert get_bigquery_type(date(2024, 12, 26)) == DataTypeGroup.DATE def test_other_types(self): assert get_bigquery_type("2024") == DataTypeGroup.STRING assert get_bigquery_type(None) == DataTypeGroup.UNDETERMINED + + +class TestSubset(object): + project_id = "moz-fx-data-shared-prod" + dataset = "test" + destination_table = "test_query_v2" + destination_table_previous = "test_query_v1" + path = Path("sql") / project_id / dataset / destination_table + path_previous = Path("sql") / project_id / dataset / destination_table_previous + + @pytest.fixture + def runner(self): + return CliRunner() + + @patch("google.cloud.bigquery.Client") + def test_version(self, mock_client): + test_tables_correct = [ + ("test_v1", 1), + ("test_v10", 10), + ("test_v0", 0), + ] + for table, expected in test_tables_correct: + test_subset = Subset( + mock_client, table, None, self.dataset, self.project_id, None + ) + assert test_subset.version == expected + + test_tables_incorrect = [ + ("test_v-19", 1), + ("test_v", 1), + ("test_3", None), + ("test_10", None), + ("test_3", 1), + ] + for table, expected in test_tables_incorrect: + test_subset = Subset( + mock_client, table, None, self.dataset, self.project_id, None + ) + with pytest.raises(click.ClickException) as e: + _ = test_subset.version + assert e.type == click.ClickException + assert ( + e.value.message + == f"Invalid or missing table version in {test_subset.destination_table}." 
+ ) + + @patch("google.cloud.bigquery.Client") + def test_partitioning(self, mock_client, runner): + """Test that partitioning type and value associated to a subset are returned as expected.""" + test_subset = Subset( + mock_client, + self.destination_table, + None, + self.dataset, + self.project_id, + None, + ) + + with runner.isolated_filesystem(): + os.makedirs(Path(self.path), exist_ok=True) + with open(Path(self.path) / "metadata.yaml", "w") as f: + f.write( + "bigquery:\n time_partitioning:\n type: day\n field: submission_date\n require_partition_filter: true" + ) + assert test_subset.partitioning == { + "field": "submission_date", + "type": "DAY", + } + + with open(Path(self.path) / "metadata.yaml", "w") as f: + f.write( + "bigquery:\n time_partitioning:\n type: day\n field: first_seen_date\n require_partition_filter: true" + ) + assert test_subset.partitioning == { + "field": "first_seen_date", + "type": "DAY", + } + + with open(Path(self.path) / "metadata.yaml", "w") as f: + f.write( + "friendly_name: ABCD\ndescription: ABCD\nlabels:\n incremental: true" + ) + assert test_subset.partitioning == {"field": None, "type": None} + + @patch("google.cloud.bigquery.Client") + def test_generate_query(self, mock_client): + """Test cases: aggregate, different columns, different metrics, missing metrics, added columns / metrics""" + test_subset = Subset( + mock_client, + self.destination_table, + None, + self.dataset, + self.project_id, + None, + ) + + test_subset_query = test_subset.generate_query( + select_list=["column_1"], + from_clause=f"{self.destination_table_previous}", + group_by_clause="ALL", + ) + assert test_subset_query == ( + f"SELECT column_1 FROM {self.destination_table_previous}" f" GROUP BY ALL" + ) + + test_subset_query = test_subset.generate_query( + select_list=[1, 2, 3], + from_clause=f"{self.destination_table_previous}", + order_by_clause="1, 2, 3", + ) + assert test_subset_query == ( + f"SELECT 1, 2, 3 FROM {self.destination_table_previous}" + f" ORDER BY 1, 2, 3" + ) + + test_subset_query = test_subset.generate_query( + select_list=["column_1", 2, 3], + from_clause=f"{self.destination_table_previous}", + group_by_clause="1, 2, 3", + ) + assert test_subset_query == ( + f"SELECT column_1, 2, 3 FROM {self.destination_table_previous}" + f" GROUP BY 1, 2, 3" + ) + + test_subset_query = test_subset.generate_query( + select_list=["column_1"], from_clause=f"{self.destination_table_previous}" + ) + assert ( + test_subset_query + == f"SELECT column_1 FROM {self.destination_table_previous}" + ) + + test_subset_query = test_subset.generate_query( + select_list=["column_1"], + from_clause=f"{self.destination_table_previous}", + where_clause="column_1 IS NOT NULL", + group_by_clause="1", + order_by_clause="1", + ) + assert test_subset_query == ( + f"SELECT column_1 FROM {self.destination_table_previous}" + f" WHERE column_1 IS NOT NULL GROUP BY 1 ORDER BY 1" + ) + + test_subset_query = test_subset.generate_query( + select_list=["column_1"], + from_clause=f"{self.destination_table_previous}", + group_by_clause="1", + order_by_clause="1", + ) + assert test_subset_query == ( + f"SELECT column_1 FROM {self.destination_table_previous}" + f" GROUP BY 1 ORDER BY 1" + ) + + test_subset_query = test_subset.generate_query( + select_list=["column_1"], + from_clause=f"{self.destination_table_previous}", + having_clause="column_1 > 1", + ) + assert ( + test_subset_query + == f"SELECT column_1 FROM {self.destination_table_previous}" + ) + + test_subset_query = test_subset.generate_query( + 
select_list=["column_1"], + from_clause=f"{self.destination_table_previous}", + group_by_clause="1", + having_clause="column_1 > 1", + ) + assert test_subset_query == ( + f"SELECT column_1 FROM {self.destination_table_previous}" + f" GROUP BY 1 HAVING column_1 > 1" + ) + + with pytest.raises(ClickException) as e: + test_subset.generate_query( + select_list=[], + from_clause=f"{self.destination_table_previous}", + group_by_clause="1", + having_clause="column_1 > 1", + ) + assert str(e.value.message) == ( + f"Missing required clause to generate query.\n" + f"Actuals: SELECT: [], FROM: {test_subset.full_table_id}" + ) + + @patch("google.cloud.bigquery.Client") + def test_get_query_path_results(self, mock_client, runner): + """Test expected results for a mocked BigQuery call.""" + test_subset = Subset( + mock_client, + self.destination_table, + None, + self.dataset, + self.project_id, + None, + ) + expected = [{"column_1": "1234"}] + + with runner.isolated_filesystem(): + os.makedirs(self.path, exist_ok=True) + with open(Path(self.path) / "query.sql", "w") as f: + f.write("SELECT column_1 WHERE submission_date = @submission_date") + + with pytest.raises(FileNotFoundError) as e: + test_subset.get_query_path_results(None) + assert "metadata.yaml" in str(e) + + with open(Path(self.path) / "metadata.yaml", "w") as f: + f.write( + "bigquery:\n time_partitioning:\n type: day\n field: submission_date" + ) + mock_query = mock_client.query + mock_query.return_value.result.return_value = iter(expected) + result = test_subset.get_query_path_results(None) + assert result == expected + + def test_generate_check_with_previous_version(self): + assert True + + +class TestGenerateQueryWithShredderMitigation(object): + """Test the function that generates the query for the backfill.""" + + project_id = "moz-fx-data-shared-prod" + dataset = "test" + destination_table = "test_query_v2" + destination_table_previous = "test_query_v1" + path = Path("sql") / project_id / dataset / destination_table + path_previous = Path("sql") / project_id / dataset / destination_table_previous + + @pytest.fixture + def runner(self): + return CliRunner() + + @patch("google.cloud.bigquery.Client") + @patch("bigquery_etl.backfill.shredder_mitigation.classify_columns") + def test_generate_query_as_expected( + self, mock_classify_columns, mock_client, runner + ): + """Test that query is generated as expected given a set of mock dimensions and metrics.""" + + expected = ( + Path("sql") + / self.project_id + / self.dataset + / self.destination_table + / f"{QUERY_WITH_MITIGATION_NAME}.sql", + """-- Query generated using a template for shredder mitigation. 
+ WITH new_version AS ( + SELECT + column_1, + column_2, + metric_1 + FROM + upstream_1 + GROUP BY + column_1, + column_2 + ), + new_agg AS ( + SELECT + submission_date, + COALESCE(column_1, '??') AS column_1, + SUM(metric_1) AS metric_1 + FROM + new_version + GROUP BY + ALL + ), + previous_agg AS ( + SELECT + submission_date, + COALESCE(column_1, '??') AS column_1, + SUM(metric_1) AS metric_1 + FROM + `moz-fx-data-shared-prod.test.test_query_v1` + WHERE + submission_date = @submission_date + GROUP BY + ALL + ), + shredded AS ( + SELECT + previous_agg.submission_date, + previous_agg.column_1, + CAST(NULL AS STRING) AS column_2, + previous_agg.metric_1 - IFNULL(new_agg.metric_1, 0) AS metric_1 + FROM + previous_agg + LEFT JOIN + new_agg + ON previous_agg.submission_date = new_agg.submission_date + AND previous_agg.column_1 = new_agg.column_1 + WHERE + previous_agg.metric_1 > IFNULL(new_agg.metric_1, 0) + ) + SELECT + column_1, + column_2, + metric_1 + FROM + new_version + UNION ALL + SELECT + column_1, + column_2, + metric_1 + FROM + shredded;""", + ) + with runner.isolated_filesystem(): + os.makedirs(self.path, exist_ok=True) + os.makedirs(self.path_previous, exist_ok=True) + with open( + Path("sql") + / self.project_id + / self.dataset + / self.destination_table + / "query.sql", + "w", + ) as f: + f.write( + "SELECT column_1, column_2, metric_1 FROM upstream_1 GROUP BY column_1, column_2" + ) + with open( + Path("sql") + / self.project_id + / self.dataset + / self.destination_table_previous + / "query.sql", + "w", + ) as f: + f.write("SELECT column_1, metric_1 FROM upstream_1 GROUP BY column_1") + + with open(Path(self.path) / "metadata.yaml", "w") as f: + f.write( + "bigquery:\n time_partitioning:\n type: day\n field: submission_date\n require_partition_filter: true" + ) + with open(Path(self.path_previous) / "metadata.yaml", "w") as f: + f.write( + "bigquery:\n time_partitioning:\n type: day\n field: submission_date\n require_partition_filter: true" + ) + + mock_classify_columns.return_value = ( + [ + Column( + "column_1", + DataTypeGroup.STRING, + ColumnType.DIMENSION, + ColumnStatus.COMMON, + ) + ], + [ + Column( + "column_2", + DataTypeGroup.STRING, + ColumnType.DIMENSION, + ColumnStatus.ADDED, + ) + ], + [], + [ + Column( + "metric_1", + DataTypeGroup.INTEGER, + ColumnType.METRIC, + ColumnStatus.COMMON, + ) + ], + [], + ) + + with patch.object( + Subset, + "get_query_path_results", + return_value=[{"column_1": "ABC", "column_2": "DEF", "metric_1": 10.0}], + ): + assert os.path.isfile(self.path / "query.sql") + assert os.path.isfile(self.path_previous / "query.sql") + result = generate_query_with_shredder_mitigation( + client=mock_client, + project_id=self.project_id, + dataset=self.dataset, + destination_table=self.destination_table, + backfill_date=PREVIOUS_DATE, + ) + assert result[0] == expected[0] + assert result[1] == expected[1].replace(" ", "") + + @patch("google.cloud.bigquery.Client") + def test_missing_previous_version(self, mock_client, runner): + """Test that the process raises an exception when previous query version is missing.""" + expected_exc = "Missing an sql file or sql text to extract the group by." 
+ + with runner.isolated_filesystem(): + path = f"sql/{self.project_id}/{self.dataset}/{self.destination_table}" + os.makedirs(path, exist_ok=True) + with open(Path(path) / "query.sql", "w") as f: + f.write("SELECT column_1, column_2 FROM upstream_1 GROUP BY column_1") + + with pytest.raises(ClickException) as e: + generate_query_with_shredder_mitigation( + client=mock_client, + project_id=self.project_id, + dataset=self.dataset, + destination_table=self.destination_table, + backfill_date=PREVIOUS_DATE, + ) + assert (str(e.value.message)) == expected_exc + assert (e.type) == ClickException + + @patch("google.cloud.bigquery.Client") + def test_invalid_group_by(self, mock_client, runner): + """Test that the process raises an exception when the GROUP BY is invalid for any query.""" + expected_exc = ( + "GROUP BY must use an explicit list of columns. " + "Avoid expressions like `GROUP BY ALL` or `GROUP BY 1, 2, 3`." + ) + # client = bigquery.Client() + project_id = "moz-fx-data-shared-prod" + dataset = "test" + destination_table = "test_query_v2" + destination_table_previous = "test_query_v1" + + # GROUP BY including a number + with runner.isolated_filesystem(): + previous_group_by = "column_1, column_2, column_3" + new_group_by = "3, column_4, column_5" + path = f"sql/{project_id}/{dataset}/{destination_table}" + path_previous = f"sql/{project_id}/{dataset}/{destination_table_previous}" + os.makedirs(path, exist_ok=True) + os.makedirs(path_previous, exist_ok=True) + with open(Path(path) / "query.sql", "w") as f: + f.write( + f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {new_group_by}" + ) + with open(Path(path_previous) / "query.sql", "w") as f: + f.write( + f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {previous_group_by}" + ) + + with pytest.raises(ClickException) as e: + generate_query_with_shredder_mitigation( + client=mock_client, + project_id=project_id, + dataset=dataset, + destination_table=destination_table, + backfill_date=PREVIOUS_DATE, + ) + assert (str(e.value.message)) == expected_exc + + # GROUP BY 1, 2, 3 + previous_group_by = "1, 2, 3" + new_group_by = "column_1, column_2, column_3" + with open(Path(path) / "query.sql", "w") as f: + f.write( + f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {new_group_by}" + ) + with open(Path(path_previous) / "query.sql", "w") as f: + f.write( + f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {previous_group_by}" + ) + with pytest.raises(ClickException) as e: + generate_query_with_shredder_mitigation( + client=mock_client, + project_id=project_id, + dataset=dataset, + destination_table=destination_table, + backfill_date=PREVIOUS_DATE, + ) + assert (str(e.value.message)) == expected_exc + + # GROUP BY ALL + previous_group_by = "column_1, column_2, column_3" + new_group_by = "ALL" + with open(Path(path) / "query.sql", "w") as f: + f.write( + f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {new_group_by}" + ) + with open(Path(path_previous) / "query.sql", "w") as f: + f.write( + f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {previous_group_by}" + ) + with pytest.raises(ClickException) as e: + generate_query_with_shredder_mitigation( + client=mock_client, + project_id=project_id, + dataset=dataset, + destination_table=destination_table, + backfill_date=PREVIOUS_DATE, + ) + assert (str(e.value.message)) == expected_exc + + # GROUP BY is missing + previous_group_by = "column_1, column_2, column_3" + with open(Path(path) / "query.sql", "w") as f: + f.write("SELECT column_1, column_2 FROM upstream_1") + with 
open(Path(path_previous) / "query.sql", "w") as f: + f.write( + f"SELECT column_1, column_2 FROM upstream_1 GROUP BY {previous_group_by}" + ) + with pytest.raises(ClickException) as e: + generate_query_with_shredder_mitigation( + client=mock_client, + project_id=project_id, + dataset=dataset, + destination_table=destination_table, + backfill_date=PREVIOUS_DATE, + ) + assert (str(e.value.message)) == expected_exc + + @patch("google.cloud.bigquery.Client") + @patch("bigquery_etl.backfill.shredder_mitigation.classify_columns") + def test_generate_query_called_with_correct_parameters( + self, mock_classify_columns, mock_client, runner + ): + with runner.isolated_filesystem(): + os.makedirs(self.path, exist_ok=True) + os.makedirs(self.path_previous, exist_ok=True) + with open(Path(self.path) / "query.sql", "w") as f: + f.write("SELECT column_1 FROM upstream_1 GROUP BY column_1") + with open(Path(self.path) / "metadata.yaml", "w") as f: + f.write( + "bigquery:\n time_partitioning:\n type: day\n field: submission_date\n require_partition_filter: true" + ) + with open(Path(self.path_previous) / "query.sql", "w") as f: + f.write("SELECT column_1 FROM upstream_1 GROUP BY column_1") + with open(Path(self.path_previous) / "metadata.yaml", "w") as f: + f.write( + "bigquery:\n time_partitioning:\n type: day\n " + "field: submission_date\n require_partition_filter: true" + ) + + mock_classify_columns.return_value = ( + [ + Column( + "column_1", + DataTypeGroup.STRING, + ColumnType.DIMENSION, + ColumnStatus.COMMON, + ) + ], + [ + Column( + "column_2", + DataTypeGroup.STRING, + ColumnType.DIMENSION, + ColumnStatus.ADDED, + ) + ], + [], + [ + Column( + "metric_1", + DataTypeGroup.INTEGER, + ColumnType.METRIC, + ColumnStatus.COMMON, + ) + ], + [], + ) + + with patch.object( + Subset, + "get_query_path_results", + return_value=[{"column_1": "ABC", "column_2": "DEF", "metric_1": 10.0}], + ): + with patch.object(Subset, "generate_query") as mock_generate_query: + generate_query_with_shredder_mitigation( + client=mock_client, + project_id=self.project_id, + dataset=self.dataset, + destination_table=self.destination_table, + backfill_date=PREVIOUS_DATE, + ) + assert mock_generate_query.call_count == 3 + mock_generate_query.assert_has_calls( + [ + call( + select_list=[ + "submission_date", + "COALESCE(column_1, '??') AS column_1", + "SUM(metric_1) AS metric_1", + ], + from_clause="new_version", + group_by_clause="ALL", + ), + call( + select_list=[ + "submission_date", + "COALESCE(column_1, '??') AS column_1", + "SUM(metric_1) AS metric_1", + ], + from_clause="`moz-fx-data-shared-prod.test.test_query_v1`", + where_clause="submission_date = @submission_date", + group_by_clause="ALL", + ), + call( + select_list=[ + "previous_agg.submission_date", + "previous_agg.column_1", + "CAST(NULL AS STRING) AS column_2", + "previous_agg.metric_1 - IFNULL(new_agg.metric_1, 0) AS metric_1", + ], + from_clause="previous_agg LEFT JOIN new_agg ON previous_agg.submission_date = " + "new_agg.submission_date AND previous_agg.column_1 = new_agg.column_1 ", + where_clause="previous_agg.metric_1 > IFNULL(new_agg.metric_1, 0)", + ), + ] + ) diff --git a/tests/cli/test_cli_utils.py b/tests/cli/test_cli_utils.py index 850c70a841..494819b50a 100644 --- a/tests/cli/test_cli_utils.py +++ b/tests/cli/test_cli_utils.py @@ -1,12 +1,9 @@ -import os from pathlib import Path import pytest from click.exceptions import BadParameter -from click.testing import CliRunner from bigquery_etl.cli.utils import ( - extract_last_group_by_from_query, 
is_authenticated, is_valid_dir, is_valid_file, @@ -18,9 +15,6 @@ TEST_DIR = Path(__file__).parent.parent class TestUtils: - @pytest.fixture - def runner(self): - return CliRunner() def test_is_valid_dir(self): with pytest.raises(BadParameter): @@ -71,89 +65,3 @@ class TestUtils: pattern="telemetry_live.event_v4", invert=True, ) - - def test_extract_last_group_by_from_query_sql(self): - """Test cases using a sql text.""" - assert ["ALL"] == extract_last_group_by_from_query( - sql_text="SELECT column_1 FROM test_table GROUP BY ALL" - ) - assert ["1"] == extract_last_group_by_from_query( - sql_text="SELECT column_1, SUM(metric_1) AS metric_1 FROM test_table GROUP BY 1;" - ) - assert ["1", "2", "3"] == extract_last_group_by_from_query( - sql_text="SELECT column_1 FROM test_table GROUP BY 1, 2, 3" - ) - assert ["1", "2", "3"] == extract_last_group_by_from_query( - sql_text="SELECT column_1 FROM test_table GROUP BY 1, 2, 3" - ) - assert ["column_1", "column_2"] == extract_last_group_by_from_query( - sql_text="""SELECT column_1, column_2 FROM test_table GROUP BY column_1, column_2 ORDER BY 1 LIMIT 100""" - ) - assert [] == extract_last_group_by_from_query( - sql_text="SELECT column_1 FROM test_table" - ) - assert [] == extract_last_group_by_from_query( - sql_text="SELECT column_1 FROM test_table;" - ) - assert ["column_1"] == extract_last_group_by_from_query( - sql_text="SELECT column_1 FROM test_table GROUP BY column_1" - ) - assert ["column_1", "column_2"] == extract_last_group_by_from_query( - sql_text="SELECT column_1, column_2 FROM test_table GROUP BY (column_1, column_2)" - ) - assert ["column_1"] == extract_last_group_by_from_query( - sql_text="""WITH cte AS (SELECT column_1 FROM test_table GROUP BY column_1) - SELECT column_1 FROM cte""" - ) - assert ["column_1"] == extract_last_group_by_from_query( - sql_text="""WITH cte AS (SELECT column_1 FROM test_table GROUP BY column_1), - cte2 AS (SELECT column_1, column2 FROM test_table GROUP BY column_1, column2) - SELECT column_1 FROM cte2 GROUP BY column_1 ORDER BY 1 DESC LIMIT 1;""" - ) - assert ["column_3"] == extract_last_group_by_from_query( - sql_text="""WITH cte1 AS (SELECT column_1, column3 FROM test_table GROUP BY column_1, column3), - cte3 AS (SELECT column_1, column3 FROM cte1 group by column_3) SELECT column_1 FROM cte3 limit 2;""" - ) - assert ["column_2"] == extract_last_group_by_from_query( - sql_text="""WITH cte1 AS (SELECT column_1 FROM test_table GROUP BY column_1), - 'cte2 AS (SELECT column_2 FROM test_table GROUP BY column_2), - cte3 AS (SELECT column_1 FROM cte1 UNION ALL SELECT column2 FROM cte2) SELECT * FROM cte3""" - ) - assert ["column_2"] == extract_last_group_by_from_query( - sql_text="""WITH cte1 AS (SELECT column_1 FROM test_table GROUP BY column_1), - cte2 AS (SELECT column_1 FROM test_table GROUP BY column_2) SELECT * FROM cte2;""" - ) - - assert ["COLUMN"] == extract_last_group_by_from_query( - sql_text="""WITH cte1 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN), - cte2 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN) SELECT * FROM cte2;""" - ) - - assert ["COLUMN"] == extract_last_group_by_from_query( - sql_text="""WITH cte1 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN), - cte2 AS (SELECT COLUMN FROM test_table group by COLUMN) SELECT * FROM cte2;""" - ) - - def test_extract_last_group_by_from_query_file(self, runner): - """Test function and cases using a sql file path.""" - with runner.isolated_filesystem(): - test_path = ( - "sql/moz-fx-data-shared-prod/test_shredder_mitigation/test_query_v1" - ) - 
os.makedirs(test_path) - assert os.path.exists(test_path) - assert "test_shredder_mitigation" in os.listdir( - "sql/moz-fx-data-shared-prod" - ) - assert is_valid_dir(None, None, test_path) - - sql_path = Path(test_path) / "query.sql" - with open(sql_path, "w") as f: - f.write("SELECT column_1 FROM test_table group by ALL") - assert ["ALL"] == extract_last_group_by_from_query(sql_path=sql_path) - - with open(sql_path, "w") as f: - f.write( - "SELECT column_1 FROM test_table GROUP BY (column_1) LIMIT (column_1);" - ) - assert ["column_1"] == extract_last_group_by_from_query(sql_path=sql_path) diff --git a/tests/util/test_common.py b/tests/util/test_common.py index 0e433dd402..36b40ec83c 100644 --- a/tests/util/test_common.py +++ b/tests/util/test_common.py @@ -1,6 +1,12 @@ -import pytest +import os +from pathlib import Path +import pytest +from click.testing import CliRunner + +from bigquery_etl.cli.utils import is_valid_dir from bigquery_etl.util.common import ( + extract_last_group_by_from_query, project_dirs, qualify_table_references_in_file, render, @@ -8,6 +14,10 @@ from bigquery_etl.util.common import ( class TestUtilCommon: + @pytest.fixture + def runner(self): + return CliRunner() + def test_project_dirs(self): assert project_dirs("test") == ["sql/test"] @@ -399,3 +409,89 @@ class TestUtilCommon: actual = qualify_table_references_in_file(query_path) assert actual == expected + + def test_extract_last_group_by_from_query_file(self, runner): + """Test cases using a sql file path.""" + with runner.isolated_filesystem(): + test_path = ( + "sql/moz-fx-data-shared-prod/test_shredder_mitigation/test_query_v1" + ) + os.makedirs(test_path) + assert os.path.exists(test_path) + assert "test_shredder_mitigation" in os.listdir( + "sql/moz-fx-data-shared-prod" + ) + assert is_valid_dir(None, None, test_path) + + sql_path = Path(test_path) / "query.sql" + with open(sql_path, "w") as f: + f.write("SELECT column_1 FROM test_table group by ALL") + assert ["ALL"] == extract_last_group_by_from_query(sql_path=sql_path) + + with open(sql_path, "w") as f: + f.write( + "SELECT column_1 FROM test_table GROUP BY (column_1) LIMIT (column_1);" + ) + assert ["column_1"] == extract_last_group_by_from_query(sql_path=sql_path) + + def test_extract_last_group_by_from_query_sql(self): + """Test cases using a sql text.""" + assert ["ALL"] == extract_last_group_by_from_query( + sql_text="SELECT column_1 FROM test_table GROUP BY ALL" + ) + assert ["1"] == extract_last_group_by_from_query( + sql_text="SELECT column_1, SUM(metric_1) AS metric_1 FROM test_table GROUP BY 1;" + ) + assert ["1", "2", "3"] == extract_last_group_by_from_query( + sql_text="SELECT column_1 FROM test_table GROUP BY 1, 2, 3" + ) + assert ["1", "2", "3"] == extract_last_group_by_from_query( + sql_text="SELECT column_1 FROM test_table GROUP BY 1, 2, 3" + ) + assert ["column_1", "column_2"] == extract_last_group_by_from_query( + sql_text="""SELECT column_1, column_2 FROM test_table GROUP BY column_1, column_2 ORDER BY 1 LIMIT 100""" + ) + assert [] == extract_last_group_by_from_query( + sql_text="SELECT column_1 FROM test_table" + ) + assert [] == extract_last_group_by_from_query( + sql_text="SELECT column_1 FROM test_table;" + ) + assert ["column_1"] == extract_last_group_by_from_query( + sql_text="SELECT column_1 FROM test_table GROUP BY column_1" + ) + assert ["column_1", "column_2"] == extract_last_group_by_from_query( + sql_text="SELECT column_1, column_2 FROM test_table GROUP BY (column_1, column_2)" + ) + assert ["column_1"] == 
extract_last_group_by_from_query( + sql_text="""WITH cte AS (SELECT column_1 FROM test_table GROUP BY column_1) + SELECT column_1 FROM cte""" + ) + assert ["column_1"] == extract_last_group_by_from_query( + sql_text="""WITH cte AS (SELECT column_1 FROM test_table GROUP BY column_1), + cte2 AS (SELECT column_1, column2 FROM test_table GROUP BY column_1, column2) + SELECT column_1 FROM cte2 GROUP BY column_1 ORDER BY 1 DESC LIMIT 1;""" + ) + assert ["column_3"] == extract_last_group_by_from_query( + sql_text="""WITH cte1 AS (SELECT column_1, column3 FROM test_table GROUP BY column_1, column3), + cte3 AS (SELECT column_1, column3 FROM cte1 group by column_3) SELECT column_1 FROM cte3 limit 2;""" + ) + assert ["column_2"] == extract_last_group_by_from_query( + sql_text="""WITH cte1 AS (SELECT column_1 FROM test_table GROUP BY column_1), + 'cte2 AS (SELECT column_2 FROM test_table GROUP BY column_2), + cte3 AS (SELECT column_1 FROM cte1 UNION ALL SELECT column2 FROM cte2) SELECT * FROM cte3""" + ) + assert ["column_2"] == extract_last_group_by_from_query( + sql_text="""WITH cte1 AS (SELECT column_1 FROM test_table GROUP BY column_1), + cte2 AS (SELECT column_1 FROM test_table GROUP BY column_2) SELECT * FROM cte2;""" + ) + + assert ["COLUMN"] == extract_last_group_by_from_query( + sql_text="""WITH cte1 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN), + cte2 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN) SELECT * FROM cte2;""" + ) + + assert ["COLUMN"] == extract_last_group_by_from_query( + sql_text="""WITH cte1 AS (SELECT COLUMN FROM test_table GROUP BY COLUMN), + cte2 AS (SELECT COLUMN FROM test_table group by COLUMN) SELECT * FROM cte2;""" + )