Improve list_tables speed for script/copy_deduplicate (#382)

Daniel Thorn 2019-09-26 11:28:22 -07:00 committed by GitHub
Parent 835e867514
Commit 6158817ea3
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file: 88 additions and 47 deletions

script/copy_deduplicate

@@ -12,11 +12,13 @@ or to process only a specific list of tables.
 from argparse import ArgumentParser
 from datetime import datetime, timedelta
-from fnmatch import fnmatch
+from itertools import groupby
 from multiprocessing.pool import ThreadPool
 from uuid import uuid4
+import fnmatch
+import re
 from google.cloud import bigquery
 QUERY_TEMPLATE = """
@@ -192,7 +194,9 @@ def sql_full_table_id(table):
     return f"{table.project}.{table.dataset_id}.{table.table_id}"
-def get_query_job_configs(client, stable_table, date, dry_run, slices, priority):
+def get_query_job_configs(client, live_table, date, dry_run, slices, priority):
+    sql = QUERY_TEMPLATE.format(live_table=live_table)
+    stable_table = f"{live_table.replace('_live.', '_stable.', 1)}${date:%Y%m%d}"
     kwargs = dict(use_legacy_sql=False, dry_run=dry_run, priority=priority)
     start_time = datetime(*date.timetuple()[:6])
     end_time = start_time + timedelta(days=1)
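With the signature change, get_query_job_configs now receives the fully qualified live table as a string and derives both the query SQL and the dated stable destination from it, instead of being handed a prebuilt stable table reference. A small worked example of the derivation (table name and date are illustrative):

from datetime import date

live_table = "moz-fx-data-shared-prod.telemetry_live.main_v4"  # illustrative
day = date(2019, 9, 25)

stable_table = f"{live_table.replace('_live.', '_stable.', 1)}${day:%Y%m%d}"
print(stable_table)
# moz-fx-data-shared-prod.telemetry_stable.main_v4$20190925
# $YYYYMMDD is BigQuery's partition decorator, so a WRITE_TRUNCATE write
# (the else branch below) replaces only that day's partition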
@@ -203,9 +207,8 @@ def get_query_job_configs(client, stable_table, date, dry_run, slices, priority)
             start_time + slice_size * i
             for i in range(slices)
         ] + [end_time]  # explicitly use end_time to avoid rounding errors
-        for i in range(slices):
-            start, end = params[i:i+2]
-            yield bigquery.QueryJobConfig(
+        return [
+            (client, sql, stable_table, bigquery.QueryJobConfig(
                 destination=get_temporary_table(
                     client=client,
                     schema=stable_table.schema,
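When --slices is set, the day is cut into equal time ranges and each slice is deduplicated into its own temporary table. The last boundary is taken verbatim from end_time rather than recomputed as start_time + slice_size * slices, so the final boundary lands exactly on midnight regardless of rounding in the divided timedelta. A minimal sketch of the boundary math:

from datetime import datetime, timedelta

start_time = datetime(2019, 9, 25)
end_time = start_time + timedelta(days=1)
slices = 7  # deliberately does not divide a day evenly

slice_size = (end_time - start_time) / slices  # rounded to whole microseconds
params = [start_time + slice_size * i for i in range(slices)] + [end_time]

# consecutive pairs become the @start_time/@end_time query parameters;
# the final pair ends exactly at midnight because end_time is appended explicitly
for start, end in zip(params, params[1:]):
    print(start, end)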
@@ -218,13 +221,15 @@ def get_query_job_configs(client, stable_table, date, dry_run, slices, priority)
                 clustering_fields=stable_table.clustering_fields,
                 time_partitioning=stable_table.time_partitioning,
                 query_parameters=[
-                    bigquery.ScalarQueryParameter("start_time", "TIMESTAMP", start),
-                    bigquery.ScalarQueryParameter("end_time", "TIMESTAMP", end),
+                    bigquery.ScalarQueryParameter("start_time", "TIMESTAMP", params[i]),
+                    bigquery.ScalarQueryParameter("end_time", "TIMESTAMP", params[i+1]),
                 ],
                 **kwargs,
-            )
+            ))
+            for i in range(slices)
+        ]
     else:
-        yield bigquery.QueryJobConfig(
+        return [(client, sql, stable_table, bigquery.QueryJobConfig(
             destination=stable_table,
             write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
             query_parameters=[
@@ -232,7 +237,7 @@ def get_query_job_configs(client, stable_table, date, dry_run, slices, priority)
                 bigquery.ScalarQueryParameter("end_time", "TIMESTAMP", end_time),
             ],
             **kwargs,
-        )
+        ))]
 def run_deduplication_query(client, sql, stable_table, job_config):
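get_query_job_configs now returns a materialized list of complete (client, sql, stable_table, job_config) tuples instead of yielding bare job configs. That appears to be what lets main() below fan the per-table setup out through ThreadPool.starmap and then flatten the results straight into job_args for run_deduplication_query; a lazy generator would defer the temporary-table setup to the consuming thread rather than the pool workers. A sketch of the consumption pattern, with a hypothetical stub standing in for the real function:

from multiprocessing.pool import ThreadPool

def job_configs_stub(client, live_table, date):
    # hypothetical stand-in shaped like get_query_job_configs' return value:
    # a list of starmap-ready (client, sql, stable_table, job_config) tuples
    return [(client, f"query for {live_table}", "stable-placeholder", {"slice": i})
            for i in range(2)]

live_tables = ["proj.telemetry_live.main_v4", "proj.telemetry_live.event_v4"]  # illustrative

with ThreadPool(4) as pool:
    job_args = [
        args
        for jobs in pool.starmap(job_configs_stub, [(None, t, "2019-09-25") for t in live_tables])
        for args in jobs
    ]

print(len(job_args))  # 4: per-table tuples stay grouped, in live_tables order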
@@ -266,49 +271,85 @@ def copy_join_parts(client, stable_table, query_jobs):
     print(f"Deleted {len(query_jobs)} temporary tables")
+def contains_glob(patterns):
+    return any(set("*?[").intersection(pattern) for pattern in patterns)
+def glob_dataset(pattern):
+    return pattern.split(".", 1)[0]
+def compile_glob_patterns(patterns):
+    return re.compile("|".join(fnmatch.translate(pattern) for pattern in patterns))
+def glob_predicate(match, table, arg):
+    matched = match(table) is not None
+    if (arg == "only" and not matched) or (arg == "except" and matched):
+        print(f"Skipping {table} due to --{arg} argument")
+    return matched
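A usage sketch of the four helpers above, with illustrative table specs, showing how shell-style globs become one compiled regex and how the --only and --except semantics differ (assumes the functions above are in scope):

tables = [
    "telemetry_live.main_v4",
    "telemetry_live.event_v4",
    "activity_stream_live.impression_stats_v1",
]

# --only keeps matches, logging whatever it skips
match = compile_glob_patterns(["telemetry_live.*"]).match
tables = [t for t in tables if glob_predicate(match, t, "only")]
# prints: Skipping activity_stream_live.impression_stats_v1 due to --only argument

# --except drops matches, logging whatever it skips
match = compile_glob_patterns(["*.event_v4"]).match
tables = [t for t in tables if not glob_predicate(match, t, "except")]
# prints: Skipping telemetry_live.event_v4 due to --except argument

print(tables)                                     # ['telemetry_live.main_v4']
print(contains_glob(["telemetry_live.main_v4"]))  # False -> list calls can be skipped
print(glob_dataset("telemetry_live.main_v*"))     # 'telemetry_live'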
+def list_live_tables(client, pool, project_id, except_tables, only_tables):
+    if only_tables and not contains_glob(only_tables):
+        # skip list calls when only_tables exists and contains no globs
+        return [f"{project_id}.{t}" for t in only_tables]
+    if only_tables and not contains_glob(glob_dataset(t) for t in only_tables):
+        # skip list_datasets call when only_tables exists and datasets contain no globs
+        live_datasets = {f"{project_id}.{glob_dataset(t)}" for t in only_tables}
+    else:
+        live_datasets = [
+            d.reference
+            for d in client.list_datasets(project_id)
+            if d.dataset_id.endswith("_live")
+        ]
+    live_tables = [
+        f"{t.dataset_id}.{t.table_id}"
+        for tables in pool.map(client.list_tables, live_datasets)
+        for t in tables
+    ]
+    if only_tables:
+        match = compile_glob_patterns(only_tables).match
+        live_tables = [t for t in live_tables if glob_predicate(match, t, "only")]
+    if except_tables:
+        match = compile_glob_patterns(except_tables).match
+        live_tables = [t for t in live_tables if not glob_predicate(match, t, "except")]
+    return [f"{project_id}.{t}" for t in live_tables]
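list_live_tables is where the speedup named in the commit title appears to come from: the per-dataset list_tables calls are issued concurrently via pool.map, and the two early branches skip the list_datasets round-trip (or all listing) when the --only patterns pin down the datasets or the exact tables. A hedged usage sketch, assuming the functions above are in scope and BigQuery credentials are available (project id illustrative):

from multiprocessing.pool import ThreadPool

from google.cloud import bigquery

client = bigquery.Client()
with ThreadPool(8) as pool:
    live_tables = list_live_tables(
        client=client,
        pool=pool,
        project_id="my-gcp-project",       # illustrative
        except_tables=None,
        only_tables=["telemetry_live.*"],  # dataset part has no glob, so list_datasets is skipped
    )
# e.g. ['my-gcp-project.telemetry_live.main_v4', 'my-gcp-project.telemetry_live.event_v4', ...]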
 def main():
     args = parser.parse_args()
     client = bigquery.Client()
-    live_datasets = [
-        d
-        for d in client.list_datasets(args.project_id)
-        if d.dataset_id.endswith("_live")
-    ]
-    job_args = []
-    for live_dataset in live_datasets:
-        stable_dataset_id = live_dataset.dataset_id[:-5] + "_stable"
-        stable_dataset = client.dataset(stable_dataset_id, args.project_id)
-        for live_table in client.list_tables(live_dataset.reference):
-            live_table_id = live_table.table_id
-            live_table_spec = f"{live_table.dataset_id}.{live_table_id}"
-            stable_table = stable_dataset.table(f"{live_table_id}${args.date:%Y%m%d}")
-            if args.except_tables is not None and any(
-                fnmatch(live_table_spec, pattern) for pattern in args.except_tables
-            ):
-                print(f"Skipping {live_table_spec} due to --except argument")
-                continue
-            if args.only_tables is not None and not any(
-                fnmatch(live_table_spec, pattern) for pattern in args.only_tables
-            ):
-                print(f"Skipping {live_table_spec} due to --only argument")
-                continue
-            sql = QUERY_TEMPLATE.format(live_table=sql_full_table_id(live_table))
-            job_args.extend(
-                (client, sql, stable_table, job_config)
-                for job_config in get_query_job_configs(
-                    client=client,
-                    stable_table=stable_table,
-                    date=args.date,
-                    dry_run=args.dry_run,
-                    slices=args.slices,
-                    priority=args.priority,
-                )
-            )
     with ThreadPool(args.parallelism) as pool:
+        live_tables = list_live_tables(
+            client=client,
+            pool=pool,
+            project_id=args.project_id,
+            except_tables=args.except_tables,
+            only_tables=args.only_tables,
+        )
+        job_args = [
+            args
+            for jobs in pool.starmap(
+                get_query_job_configs,
+                [
+                    (
+                        client,
+                        live_table,
+                        args.date,
+                        args.dry_run,
+                        args.slices,
+                        args.priority
+                    )
+                    for live_table in live_tables
+                ]
+            )
+            for args in jobs
+        ]
+        # preserve job_args order so results stay sorted by stable_table for groupby
         results = pool.starmap(run_deduplication_query, job_args, chunksize=1)
         copy_args = [