From 30bfe79f5f25b8c4fcb43407a9fd130e080f1839 Mon Sep 17 00:00:00 2001 From: Anthony Miyaguchi Date: Fri, 15 Nov 2019 13:47:44 -0800 Subject: [PATCH] Add options to the cli for passing in arguments to postgres db --- mozaggregator/cli.py | 26 +++++++++++-- tests/test_db.py | 90 ++++++++++++++++++++++++++++++++------------ 2 files changed, 89 insertions(+), 27 deletions(-) diff --git a/mozaggregator/cli.py b/mozaggregator/cli.py index b8d230b..f9ce847 100644 --- a/mozaggregator/cli.py +++ b/mozaggregator/cli.py @@ -23,6 +23,11 @@ def entry_point(): ) @click.option("--credentials-bucket", type=str, required=False) @click.option("--credentials-prefix", type=str, required=False) +@click.option("--postgres-db", type=str, required=False) +@click.option("--postgres-user", type=str, required=False) +@click.option("--postgres-pass", type=str, required=False) +@click.option("--postgres-host", type=str, required=False) +@click.option("--postgres-ro-host", type=str, required=False) @click.option("--num-partitions", type=int, default=10000) @click.option( "--source", @@ -40,6 +45,11 @@ def run_aggregator( credentials_protocol, credentials_bucket, credentials_prefix, + postgres_db, + postgres_user, + postgres_pass, + postgres_host, + postgres_ro_host, num_partitions, source, project_id, @@ -55,12 +65,22 @@ def run_aggregator( mapping = {"file": "file", "s3": "s3a", "gcs": "gs"} return f"{mapping[protocol]}://{bucket}/{prefix}" - if credentials_bucket and credentials_prefix: + # priority of reading credentials is options > credentials file > environment + option_credentials = { + "POSTGRES_DB": postgres_db, + "POSTGRES_USER": postgres_user, + "POSTGRES_PASS": postgres_pass, + "POSTGRES_HOST": postgres_host, + "POSTGRES_RO_HOST": postgres_ro_host, + } + if all(option_credentials.values()): + print("reading credentials from options") + environ.update(option_credentials) + elif credentials_bucket and credentials_prefix: path = create_path(credentials_protocol, credentials_bucket, 
credentials_prefix) print(f"reading credentials from {path}") creds = spark.read.json(path, multiLine=True).first().asDict() - for k, v in creds.items(): - environ[k] = v + environ.update(creds) else: print(f"assuming credentials from the environment") diff --git a/tests/test_db.py b/tests/test_db.py index 554b90e..d09c296 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -134,16 +134,8 @@ def test_notice_logging_cursor(): lc.check(expected) -def test_aggregation_cli(tmp_path, monkeypatch, spark): - monkeypatch.setenv("AWS_ACCESS_KEY_ID", "access") - monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret") - - test_creds = str(tmp_path / "creds") - # generally points to the production credentials - creds = {"DB_TEST_URL": "dbname=postgres user=postgres host=db"} - with open(test_creds, "w") as f: - json.dump(creds, f) - +@pytest.fixture +def mock_dataset(monkeypatch, spark): class Dataset: @staticmethod def from_source(*args, **kwargs): @@ -157,6 +149,14 @@ def test_aggregation_cli(tmp_path, monkeypatch, spark): monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset) + +def test_aggregation_cli(tmp_path, mock_dataset): + test_creds = str(tmp_path / "creds") + # generally points to the production credentials + creds = {"DB_TEST_URL": "dbname=postgres user=postgres host=db"} + with open(test_creds, "w") as f: + json.dump(creds, f) + result = CliRunner().invoke( run_aggregator, [ @@ -180,20 +180,7 @@ def test_aggregation_cli(tmp_path, monkeypatch, spark): assert_new_db_functions_backwards_compatible() -def test_aggregation_cli_no_credentials_file(monkeypatch, spark): - class Dataset: - @staticmethod - def from_source(*args, **kwargs): - return Dataset() - - def where(self, *args, **kwargs): - return self - - def records(self, *args, **kwargs): - return spark.sparkContext.parallelize(generate_pings()) - - monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset) - +def test_aggregation_cli_no_credentials_file(mock_dataset): result = 
CliRunner().invoke( run_aggregator, [ @@ -219,6 +206,61 @@ def test_aggregation_cli_no_credentials_file(monkeypatch, spark): assert_new_db_functions_backwards_compatible() +def test_aggregation_cli_credentials_option(mock_dataset): + empty_env = { + "DB_TEST_URL": "", + "POSTGRES_DB": "", + "POSTGRES_USER": "", + "POSTGRES_PASS": "", + "POSTGRES_HOST": "", + "POSTGRES_RO_HOST": "", + } + options = [ + "--postgres-db", + "postgres", + "--postgres-user", + "postgres", + "--postgres-pass", + "pass", + "--postgres-host", + "db", + "--postgres-ro-host", + "db" + ] + result = CliRunner().invoke( + run_aggregator, + [ + "--date", + SUBMISSION_DATE_1.strftime('%Y%m%d'), + "--channels", + "nightly,beta", + "--num-partitions", + 10, + ] + options, + env=empty_env, + catch_exceptions=False, + ) + + assert result.exit_code == 0, result.output + assert_new_db_functions_backwards_compatible() + + # now test that missing an option will exit with non-zero + result = CliRunner().invoke( + run_aggregator, + [ + "--date", + SUBMISSION_DATE_1.strftime('%Y%m%d'), + "--channels", + "nightly,beta", + "--num-partitions", + 10, + ] + options[:-2], # missing ro_host + env=empty_env, + catch_exceptions=False, + ) + assert result.exit_code == 1 + + @runif_bigquery_testing_enabled def test_aggregation_cli_bigquery(tmp_path, bq_testing_table): test_creds = str(tmp_path / "creds")