bigquery-etl/bigquery_etl/schema/__init__.py

"""Query schema."""
import json
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Iterable, List, Optional
import attr
import yaml
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
from google.cloud.bigquery import SchemaField
from .. import dryrun
SCHEMA_FILE = "schema.yaml"
@attr.s(auto_attribs=True)
class Schema:
"""Query schema representation and helpers."""
schema: Dict[str, Any]
_type_mapping: Dict[str, str] = {
"INT64": "INTEGER",
"BOOL": "BOOLEAN",
"FLOAT64": "FLOAT",
}
@classmethod
def from_query_file(cls, query_file: Path, *args, **kwargs):
"""Create schema from a query file."""
if not query_file.is_file() or query_file.suffix != ".sql":
raise Exception(f"{query_file} is not a valid SQL file.")
schema = dryrun.DryRun(str(query_file), *args, **kwargs).get_schema()
return cls(schema)
@classmethod
def from_schema_file(cls, schema_file: Path):
"""Create schema from a yaml schema file."""
if not schema_file.is_file() or schema_file.suffix != ".yaml":
raise Exception(f"{schema_file} is not a valid YAML schema file.")
with open(schema_file) as file:
schema = yaml.load(file, Loader=yaml.FullLoader)
return cls(schema)
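
    # Illustrative only: a minimal schema.yaml (see SCHEMA_FILE) of the shape that
    # from_schema_file() accepts; the field names below are placeholders.
    #
    #     fields:
    #     - name: submission_date
    #       type: DATE
    #       mode: NULLABLE
    #     - name: metrics
    #       type: RECORD
    #       mode: REPEATED
    #       fields:
    #       - name: name
    #         type: STRING
    #         mode: NULLABLE
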
@classmethod
def empty(cls):
"""Create an empty schema."""
return cls({"fields": []})
@classmethod
def from_json(cls, json_schema):
"""Create schema from JSON object."""
return cls(json_schema)
@classmethod
def for_table(cls, project, dataset, table, partitioned_by=None, *args, **kwargs):
"""Get the schema for a BigQuery table."""
query = f"SELECT * FROM `{project}.{dataset}.{table}`"
if partitioned_by:
query += f" WHERE DATE(`{partitioned_by}`) = DATE('2020-01-01')"
try:
return cls(
dryrun.DryRun(
os.path.join(project, dataset, table, "query.sql"),
query,
project=project,
dataset=dataset,
table=table,
*args,
**kwargs,
).get_schema()
)
except Exception as e:
print(f"Cannot get schema for {project}.{dataset}.{table}: {e}")
return cls({"fields": []})
def deploy(self, destination_table: str) -> bigquery.Table:
"""Deploy the schema to BigQuery named after destination_table."""
client = bigquery.Client()
tmp_schema_file = NamedTemporaryFile()
self.to_json_file(Path(tmp_schema_file.name))
bigquery_schema = client.schema_from_json(tmp_schema_file.name)
try:
# destination table already exists, update schema
table = client.get_table(destination_table)
table.schema = bigquery_schema
return client.update_table(table, ["schema"])
except NotFound:
table = bigquery.Table(destination_table, schema=bigquery_schema)
return client.create_table(table)
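
    # Usage sketch (placeholder table ID): deploy a checked-in schema to BigQuery.
    #
    #     schema = Schema.from_schema_file(Path("sql/my_table_v1") / SCHEMA_FILE)
    #     schema.deploy("my-project.my_dataset.my_table_v1")
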
def merge(
self,
other: "Schema",
exclude: Optional[List[str]] = None,
add_missing_fields=True,
attributes: Optional[List[str]] = None,
ignore_incompatible_fields: bool = False,
ignore_missing_fields: bool = False,
):
"""Merge another schema into the schema."""
if "fields" in other.schema and "fields" in self.schema:
self._traverse(
"root",
self.schema["fields"],
other.schema["fields"],
update=True,
exclude=exclude,
add_missing_fields=add_missing_fields,
attributes=attributes,
ignore_incompatible_fields=ignore_incompatible_fields,
ignore_missing_fields=ignore_missing_fields,
)
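
    # Usage sketch (assumed file layout; paths are placeholders):
    #
    #     path = Path("sql/my_dataset/my_table_v1")
    #     query_schema = Schema.from_query_file(path / "query.sql")
    #     table_schema = Schema.from_schema_file(path / SCHEMA_FILE)
    #     # fill in field descriptions from the schema file where the query
    #     # schema lacks them, without adding fields the query does not produce
    #     query_schema.merge(
    #         table_schema,
    #         attributes=["description"],
    #         add_missing_fields=False,
    #         ignore_missing_fields=True,
    #     )
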
def equal(self, other: "Schema") -> bool:
"""Compare to another schema."""
try:
self._traverse(
"root", self.schema["fields"], other.schema["fields"], update=False
)
self._traverse(
"root", other.schema["fields"], self.schema["fields"], update=False
)
except Exception as e:
print(e)
return False
return True
def compatible(self, other: "Schema") -> bool:
"""
Check if schema is compatible with another schema.
If there is a field missing in the schema that is part of the "other" schema,
the schemas are still compatible. However, if there are fields missing in the
"other" schema they are not compatible since, e.g. inserting data into the "other"
schema that follows this schema would fail.
"""
try:
self._traverse(
"root",
self.schema["fields"],
other.schema["fields"],
update=False,
ignore_missing_fields=True,
)
self._traverse(
"root",
other.schema["fields"],
self.schema["fields"],
update=False,
ignore_missing_fields=False,
)
except Exception as e:
print(e)
return False
return True
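
    # Illustrative example of the asymmetry (field names are placeholders):
    #
    #     a = Schema.from_json({"fields": [{"name": "x", "type": "INT64"}]})
    #     b = Schema.from_json({"fields": [
    #         {"name": "x", "type": "INT64"},
    #         {"name": "y", "type": "STRING"},
    #     ]})
    #     a.compatible(b)  # True: every field of `a` also exists in `b`
    #     b.compatible(a)  # False: `b` has a field `y` that `a` lacks
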
@staticmethod
def _node_with_mode(node):
"""Add default value for mode to node."""
if "mode" in node:
return node
return {"mode": "NULLABLE", **node}
def _traverse(
self,
prefix,
columns,
other_columns,
update=False,
add_missing_fields=True,
ignore_missing_fields=False,
exclude=None,
attributes=None,
ignore_incompatible_fields=False,
):
"""Traverses two schemas for validation and optionally updates the first schema."""
nodes = {n["name"]: Schema._node_with_mode(n) for n in columns}
other_nodes = {
n["name"]: Schema._node_with_mode(n)
for n in other_columns
if exclude is None or n["name"] not in exclude
}
for node_name, node in other_nodes.items():
field_path = node["name"] + (".[]" if node["mode"] == "REPEATED" else "")
dtype = node["type"]
if node_name in nodes:
# node exists in schema, update attributes where necessary
for node_attr_key, node_attr_value in node.items():
if attributes and node_attr_key not in attributes:
continue
if node_attr_key == "type":
# sometimes types have multiple names (e.g. INT64 and INTEGER)
# make it consistent here
node_attr_value = self._type_mapping.get(
node_attr_value, node_attr_value
)
nodes[node_name][node_attr_key] = self._type_mapping.get(
nodes[node_name][node_attr_key],
nodes[node_name][node_attr_key],
)
if node_attr_key not in nodes[node_name]:
if update:
# add field attributes if not exists in schema
nodes[node_name][node_attr_key] = node_attr_value
# Netlify has a problem starting 2022-03-07 where lots of
# logging slows down builds to the point where our builds hit
# the time limit and fail (bug 1761292), and this print
# statement accounts for 84% of our build logging.
# TODO: Uncomment this print when Netlify fixes the problem.
# print(
# f"Attribute {node_attr_key} added to {prefix}.{field_path}"
# )
else:
if node_attr_key == "description":
print(
"Warning: descriptions for "
f"{prefix}.{field_path} differ"
)
else:
if not ignore_incompatible_fields:
raise Exception(
f"{node_attr_key} missing in {prefix}.{field_path}"
)
elif nodes[node_name][node_attr_key] != node_attr_value:
# check field attribute diffs
if node_attr_key == "description":
                            # descriptions differ; keep the existing description and warn
print(
f"Warning: descriptions for {prefix}.{field_path} differ."
)
elif node_attr_key != "fields":
if not ignore_incompatible_fields:
raise Exception(
f"Cannot merge schemas. {node_attr_key} attributes "
f"for {prefix}.{field_path} are incompatible"
)
if dtype == "RECORD" and nodes[node_name]["type"] == "RECORD":
# keep traversing nested fields
self._traverse(
f"{prefix}.{field_path}",
nodes[node_name]["fields"],
node["fields"],
update=update,
add_missing_fields=add_missing_fields,
ignore_missing_fields=ignore_missing_fields,
attributes=attributes,
ignore_incompatible_fields=ignore_incompatible_fields,
)
else:
if update and add_missing_fields:
# node does not exist in schema, add to schema
columns.append(node.copy())
print(f"Field {node_name} added to {prefix}")
else:
if not ignore_missing_fields:
raise Exception(
f"Field {prefix}.{field_path} is missing in schema"
)
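
    # Note on type normalization (illustrative field name): _traverse() treats the
    # aliases in _type_mapping as equal, so e.g.
    #
    #     a = Schema.from_json({"fields": [{"name": "x", "type": "INTEGER"}]})
    #     b = Schema.from_json({"fields": [{"name": "x", "type": "INT64"}]})
    #     a.equal(b)  # True: INT64 is normalized to INTEGER before comparison
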
def to_yaml_file(self, yaml_path: Path):
"""Write schema to the YAML file path."""
with open(yaml_path, "w") as out:
yaml.dump(self.schema, out, default_flow_style=False, sort_keys=False)
def to_json_file(self, json_path: Path):
"""Write schema to the JSON file path."""
with open(json_path, "w") as out:
json.dump(self.schema["fields"], out, indent=2)
def to_json(self):
"""Return the schema data as JSON."""
return json.dumps(self.schema)
def to_bigquery_schema(self) -> List[SchemaField]:
"""Get the BigQuery representation of the schema."""
return [SchemaField.from_api_repr(field) for field in self.schema["fields"]]
@classmethod
def from_bigquery_schema(cls, fields: List[SchemaField]) -> "Schema":
"""Construct a Schema from the BigQuery representation."""
return cls({"fields": [field.to_api_repr() for field in fields]})
def generate_compatible_select_expression(
self,
target_schema: "Schema",
fields_to_remove: Optional[Iterable[str]] = None,
unnest_structs: bool = False,
max_unnest_depth: int = 0,
unnest_allowlist: Optional[Iterable[str]] = None,
) -> str:
"""Generate the select expression for the source schema based on the target schema.
The output will include all fields of the target schema in the same order of the target.
Any fields that are missing in the source schema are set to NULL.
:param target_schema: The schema to coerce the current schema to.
:param fields_to_remove: Given fields are removed from the output expression. Expressed as a
list of strings with `.` separating each level of nesting, e.g. record_name.field.
:param unnest_structs: If true, all record fields are expressed as structs with all nested
fields explicitly listed. This allows the expression to be compatible even if the
source schemas get new fields added. Otherwise, records are only unnested if they
do not match the target schema.
:param max_unnest_depth: Maximum level of struct nesting to explicitly unnest in
the expression.
:param unnest_allowlist: If set, only the given top-level structs are unnested.
"""
def _type_info(node):
"""Determine the BigQuery type information from Schema object field."""
dtype = node["type"]
if dtype == "RECORD":
dtype = (
"STRUCT<"
+ ", ".join(
f"`{field['name']}` {_type_info(field)}"
for field in node["fields"]
)
+ ">"
)
elif dtype == "FLOAT":
dtype = "FLOAT64"
if node.get("mode") == "REPEATED":
return f"ARRAY<{dtype}>"
return dtype
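
        # For reference (illustrative field names), a REPEATED RECORD node such as
        #     {"name": "m", "type": "RECORD", "mode": "REPEATED",
        #      "fields": [{"name": "k", "type": "STRING"}]}
        # renders as ARRAY<STRUCT<`k` STRING>>.
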
def recurse_fields(
_source_schema_nodes: List[Dict],
_target_schema_nodes: List[Dict],
path=None,
) -> str:
if path is None:
path = []
select_expr = []
source_schema_nodes = {n["name"]: n for n in _source_schema_nodes}
target_schema_nodes = {n["name"]: n for n in _target_schema_nodes}
# iterate through fields
for node_name, node in target_schema_nodes.items():
dtype = node["type"]
node_path = path + [node_name]
node_path_str = ".".join(node_path)
                if node_name in source_schema_nodes:  # field exists in source schema
# field matches, can query as-is
if node == source_schema_nodes[node_name] and (
# don't need to unnest scalar
dtype != "RECORD"
or not unnest_structs
# reached max record depth to unnest
or len(node_path) > max_unnest_depth > 0
# field not in unnest allowlist
or (
unnest_allowlist is not None
and node_path[0] not in unnest_allowlist
)
):
if (
fields_to_remove is None
or node_path_str not in fields_to_remove
):
select_expr.append(node_path_str)
elif (
dtype == "RECORD"
): # for nested fields, recursively generate select expression
if (
node.get("mode", None) == "REPEATED"
): # unnest repeated record
select_expr.append(
f"""
ARRAY(
SELECT
STRUCT(
{recurse_fields(
source_schema_nodes[node_name]['fields'],
node['fields'],
[node_name],
)}
)
FROM UNNEST({node_path_str}) AS `{node_name}`
) AS `{node_name}`
"""
)
else: # select struct fields
select_expr.append(
f"""
STRUCT(
{recurse_fields(
source_schema_nodes[node_name]['fields'],
node['fields'],
node_path,
)}
) AS `{node_name}`
"""
)
else: # scalar value doesn't match, e.g. different types
select_expr.append(
f"CAST(NULL AS {_type_info(node)}) AS `{node_name}`"
)
else: # field not found in source schema
select_expr.append(
f"CAST(NULL AS {_type_info(node)}) AS `{node_name}`"
)
return ", ".join(select_expr)
return recurse_fields(
self.schema["fields"],
target_schema.schema["fields"],
)
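
    # Usage sketch (placeholder table names): build a SELECT that copies rows from
    # a source table into a table with the target schema, filling fields the source
    # lacks with typed NULLs.
    #
    #     source = Schema.for_table("my-project", "my_dataset", "source_v1")
    #     target = Schema.for_table("my-project", "my_dataset", "target_v1")
    #     select_expr = source.generate_compatible_select_expression(target)
    #     query = f"SELECT {select_expr} FROM `my-project.my_dataset.source_v1`"
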
def generate_select_expression(
self,
remove_fields: Optional[Iterable[str]] = None,
unnest_structs: bool = False,
max_unnest_depth: int = 0,
unnest_allowlist: Optional[Iterable[str]] = None,
) -> str:
"""Generate the select expression for the schema which includes each field."""
return self.generate_compatible_select_expression(
self,
remove_fields,
unnest_structs,
max_unnest_depth,
unnest_allowlist,
)