Add options to the CLI for passing arguments to the Postgres DB

Anthony Miyaguchi 2019-11-15 13:47:44 -08:00 committed by Anthony Miyaguchi
Parent d084522ec6
Commit 30bfe79f5f
2 changed files with 89 additions and 27 deletions
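
For reference, a minimal sketch of invoking the command with the new options. The import path mozaggregator.cli is an assumption; run_aggregator and CliRunner appear in the tests below, and all five options must be supplied together for them to take effect:

    from click.testing import CliRunner

    from mozaggregator.cli import run_aggregator  # assumed import path

    # Sketch only: assumes a reachable Postgres instance with these credentials.
    result = CliRunner().invoke(
        run_aggregator,
        [
            "--date", "20191115",
            "--channels", "nightly,beta",
            "--postgres-db", "postgres",
            "--postgres-user", "postgres",
            "--postgres-pass", "pass",
            "--postgres-host", "db",
            "--postgres-ro-host", "db",
        ],
    )
    print(result.exit_code)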

View file

@@ -23,6 +23,11 @@ def entry_point():
 )
 @click.option("--credentials-bucket", type=str, required=False)
 @click.option("--credentials-prefix", type=str, required=False)
+@click.option("--postgres-db", type=str, required=False)
+@click.option("--postgres-user", type=str, required=False)
+@click.option("--postgres-pass", type=str, required=False)
+@click.option("--postgres-host", type=str, required=False)
+@click.option("--postgres-ro-host", type=str, required=False)
 @click.option("--num-partitions", type=int, default=10000)
 @click.option(
     "--source",
@@ -40,6 +45,11 @@ def run_aggregator(
     credentials_protocol,
     credentials_bucket,
     credentials_prefix,
+    postgres_db,
+    postgres_user,
+    postgres_pass,
+    postgres_host,
+    postgres_ro_host,
     num_partitions,
     source,
     project_id,
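
Click derives the underscored parameters above from the dashed option names (e.g. --postgres-ro-host arrives as postgres_ro_host). A standalone sketch of that mapping:

    import click

    @click.command()
    @click.option("--postgres-ro-host", type=str, required=False)
    def cmd(postgres_ro_host):
        # the value passed to --postgres-ro-host arrives as `postgres_ro_host`
        click.echo(postgres_ro_host or "unset")

    if __name__ == "__main__":
        cmd()  # e.g. `python sketch.py --postgres-ro-host db`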
@@ -55,12 +65,22 @@ def run_aggregator(
         mapping = {"file": "file", "s3": "s3a", "gcs": "gs"}
         return f"{mapping[protocol]}://{bucket}/{prefix}"
 
-    if credentials_bucket and credentials_prefix:
+    # priority of reading credentials is options > credentials file > environment
+    option_credentials = {
+        "POSTGRES_DB": postgres_db,
+        "POSTGRES_USER": postgres_user,
+        "POSTGRES_PASS": postgres_pass,
+        "POSTGRES_HOST": postgres_host,
+        "POSTGRES_RO_HOST": postgres_ro_host,
+    }
+    if all(option_credentials.values()):
+        print("reading credentials from options")
+        environ.update(option_credentials)
+    elif credentials_bucket and credentials_prefix:
         path = create_path(credentials_protocol, credentials_bucket, credentials_prefix)
         print(f"reading credentials from {path}")
         creds = spark.read.json(path, multiLine=True).first().asDict()
-        for k, v in creds.items():
-            environ[k] = v
+        environ.update(creds)
     else:
         print(f"assuming credentials from the environment")

View file

@@ -134,16 +134,8 @@ def test_notice_logging_cursor():
     lc.check(expected)
 
-def test_aggregation_cli(tmp_path, monkeypatch, spark):
-    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "access")
-    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret")
-
-    test_creds = str(tmp_path / "creds")
-    # generally points to the production credentials
-    creds = {"DB_TEST_URL": "dbname=postgres user=postgres host=db"}
-    with open(test_creds, "w") as f:
-        json.dump(creds, f)
-
+@pytest.fixture
+def mock_dataset(monkeypatch, spark):
     class Dataset:
         @staticmethod
         def from_source(*args, **kwargs):
@@ -157,6 +149,14 @@ def test_aggregation_cli(tmp_path, monkeypatch, spark):
     monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset)
 
+
+def test_aggregation_cli(tmp_path, mock_dataset):
+    test_creds = str(tmp_path / "creds")
+    # generally points to the production credentials
+    creds = {"DB_TEST_URL": "dbname=postgres user=postgres host=db"}
+    with open(test_creds, "w") as f:
+        json.dump(creds, f)
+
     result = CliRunner().invoke(
         run_aggregator,
         [
@@ -180,20 +180,7 @@ def test_aggregation_cli(tmp_path, monkeypatch, spark):
     assert_new_db_functions_backwards_compatible()
 
-def test_aggregation_cli_no_credentials_file(monkeypatch, spark):
-    class Dataset:
-        @staticmethod
-        def from_source(*args, **kwargs):
-            return Dataset()
-
-        def where(self, *args, **kwargs):
-            return self
-
-        def records(self, *args, **kwargs):
-            return spark.sparkContext.parallelize(generate_pings())
-
-    monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset)
-
+def test_aggregation_cli_no_credentials_file(mock_dataset):
     result = CliRunner().invoke(
         run_aggregator,
         [
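
The hunks above deduplicate the Dataset stub into a mock_dataset fixture; any test that lists the fixture as an argument gets the monkeypatch applied before its body runs. A minimal sketch of the pattern:

    import pytest

    @pytest.fixture
    def mock_dataset(monkeypatch):
        class Dataset:
            # stand-in for the real telemetry Dataset
            pass

        monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset)

    def test_uses_stub(mock_dataset):
        # mozaggregator.aggregator.Dataset is already patched here
        ...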
@@ -219,6 +206,61 @@ def test_aggregation_cli_no_credentials_file(monkeypatch, spark):
     assert_new_db_functions_backwards_compatible()
 
 
+def test_aggregation_cli_credentials_option(mock_dataset):
+    empty_env = {
+        "DB_TEST_URL": "",
+        "POSTGRES_DB": "",
+        "POSTGRES_USER": "",
+        "POSTGRES_PASS": "",
+        "POSTGRES_HOST": "",
+        "POSTGRES_RO_HOST": "",
+    }
+    options = [
+        "--postgres-db",
+        "postgres",
+        "--postgres-user",
+        "postgres",
+        "--postgres-pass",
+        "pass",
+        "--postgres-host",
+        "db",
+        "--postgres-ro-host",
+        "db",
+    ]
+
+    result = CliRunner().invoke(
+        run_aggregator,
+        [
+            "--date",
+            SUBMISSION_DATE_1.strftime('%Y%m%d'),
+            "--channels",
+            "nightly,beta",
+            "--num-partitions",
+            10,
+        ] + options,
+        env=empty_env,
+        catch_exceptions=False,
+    )
+    assert result.exit_code == 0, result.output
+    assert_new_db_functions_backwards_compatible()
+
+    # now test that missing an option will exit with non-zero
+    result = CliRunner().invoke(
+        run_aggregator,
+        [
+            "--date",
+            SUBMISSION_DATE_1.strftime('%Y%m%d'),
+            "--channels",
+            "nightly,beta",
+            "--num-partitions",
+            10,
+        ] + options[:-2],  # missing ro_host
+        env=empty_env,
+        catch_exceptions=False,
+    )
+    assert result.exit_code == 1
+
+
 @runif_bigquery_testing_enabled
 def test_aggregation_cli_bigquery(tmp_path, bq_testing_table):
     test_creds = str(tmp_path / "creds")
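
In the credentials-option test above, options[:-2] drops the trailing ["--postgres-ro-host", "db"] pair, so all(option_credentials.values()) is false, the command falls through to the empty environment, and it exits non-zero. A quick check of the slice:

    options = [
        "--postgres-db", "postgres",
        "--postgres-user", "postgres",
        "--postgres-pass", "pass",
        "--postgres-host", "db",
        "--postgres-ro-host", "db",
    ]
    # dropping the last flag/value pair leaves ro_host unset
    assert options[:-2][-2:] == ["--postgres-host", "db"]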