Add options to the CLI for passing arguments to the Postgres DB
Parent: d084522ec6
Commit: 30bfe79f5f
@@ -23,6 +23,11 @@ def entry_point():
 )
 @click.option("--credentials-bucket", type=str, required=False)
 @click.option("--credentials-prefix", type=str, required=False)
+@click.option("--postgres-db", type=str, required=False)
+@click.option("--postgres-user", type=str, required=False)
+@click.option("--postgres-pass", type=str, required=False)
+@click.option("--postgres-host", type=str, required=False)
+@click.option("--postgres-ro-host", type=str, required=False)
 @click.option("--num-partitions", type=int, default=10000)
 @click.option(
     "--source",
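
Taken together with the pre-existing --date, --channels, and --source options on this command, the new flags can be supplied directly at invocation time. A minimal sketch of such an invocation using click's test runner, as the project's tests do below; the mozaggregator.cli import path and the flag values are assumptions, not part of this diff:

from click.testing import CliRunner

from mozaggregator.cli import run_aggregator  # assumed module path

result = CliRunner().invoke(
    run_aggregator,
    [
        "--date", "20190901",          # placeholder date
        "--channels", "nightly",
        "--postgres-db", "postgres",
        "--postgres-user", "postgres",
        "--postgres-pass", "pass",
        "--postgres-host", "db",
        "--postgres-ro-host", "db",
    ],
)
print(result.exit_code)
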
@@ -40,6 +45,11 @@ def run_aggregator(
     credentials_protocol,
     credentials_bucket,
     credentials_prefix,
+    postgres_db,
+    postgres_user,
+    postgres_pass,
+    postgres_host,
+    postgres_ro_host,
     num_partitions,
     source,
     project_id,
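
click passes each option into run_aggregator as a keyword argument, deriving the parameter name from the flag by swapping dashes for underscores, which is why the signature above gains postgres_db through postgres_ro_host. A self-contained illustration of that mapping (not part of the commit):

import click

@click.command()
@click.option("--postgres-db", type=str, required=False)
def demo(postgres_db):
    # "--postgres-db" arrives as the postgres_db parameter; None when omitted.
    click.echo(postgres_db or "<not provided>")

if __name__ == "__main__":
    demo()
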
@@ -55,12 +65,22 @@ def run_aggregator(
         mapping = {"file": "file", "s3": "s3a", "gcs": "gs"}
         return f"{mapping[protocol]}://{bucket}/{prefix}"

-    if credentials_bucket and credentials_prefix:
+    # priority of reading credentials is options > credentials file > environment
+    option_credentials = {
+        "POSTGRES_DB": postgres_db,
+        "POSTGRES_USER": postgres_user,
+        "POSTGRES_PASS": postgres_pass,
+        "POSTGRES_HOST": postgres_host,
+        "POSTGRES_RO_HOST": postgres_ro_host,
+    }
+    if all(option_credentials.values()):
+        print("reading credentials from options")
+        environ.update(option_credentials)
+    elif credentials_bucket and credentials_prefix:
         path = create_path(credentials_protocol, credentials_bucket, credentials_prefix)
         print(f"reading credentials from {path}")
         creds = spark.read.json(path, multiLine=True).first().asDict()
-        for k, v in creds.items():
-            environ[k] = v
+        environ.update(creds)
     else:
         print(f"assuming credentials from the environment")

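
The block above implements the precedence its comment announces: if every --postgres-* flag was provided, the options win; otherwise a credentials file is read through Spark; failing both, whatever is already in the environment is trusted. A condensed sketch of that decision, with the helper name and the read_json hook invented for illustration:

from os import environ

def resolve_credentials(option_credentials, credentials_path=None, read_json=None):
    """Apply credentials to the environment: options > credentials file > environment."""
    if all(option_credentials.values()):
        # every --postgres-* option was supplied, so they take priority
        environ.update(option_credentials)
    elif credentials_path and read_json is not None:
        # e.g. spark.read.json(path, multiLine=True).first().asDict()
        environ.update(read_json(credentials_path))
    # otherwise fall through: the ambient environment is used as-is
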
@@ -134,16 +134,8 @@ def test_notice_logging_cursor():
     lc.check(expected)


-def test_aggregation_cli(tmp_path, monkeypatch, spark):
-    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "access")
-    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret")
-
-    test_creds = str(tmp_path / "creds")
-    # generally points to the production credentials
-    creds = {"DB_TEST_URL": "dbname=postgres user=postgres host=db"}
-    with open(test_creds, "w") as f:
-        json.dump(creds, f)
-
+@pytest.fixture
+def mock_dataset(monkeypatch, spark):
     class Dataset:
         @staticmethod
         def from_source(*args, **kwargs):
@@ -157,6 +149,14 @@ def test_aggregation_cli(tmp_path, monkeypatch, spark):

     monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset)

+
+def test_aggregation_cli(tmp_path, mock_dataset):
+    test_creds = str(tmp_path / "creds")
+    # generally points to the production credentials
+    creds = {"DB_TEST_URL": "dbname=postgres user=postgres host=db"}
+    with open(test_creds, "w") as f:
+        json.dump(creds, f)
+
     result = CliRunner().invoke(
         run_aggregator,
         [
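
The refactor above hoists the duplicated Dataset stub out of the individual tests and into a pytest fixture, so any test that names mock_dataset in its signature gets the patch applied before its body runs. A condensed, self-contained version of the pattern (the patch target comes from the diff; the stub is abbreviated and the fixture's spark argument is omitted):

import pytest

class Dataset:
    @staticmethod
    def from_source(*args, **kwargs):
        return Dataset()

@pytest.fixture
def mock_dataset(monkeypatch):
    # replace the real Dataset for the duration of the requesting test
    monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset)

def test_aggregation_cli(mock_dataset):
    ...  # the CLI under test now sees the stubbed Dataset
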
@@ -180,20 +180,7 @@ def test_aggregation_cli(tmp_path, monkeypatch, spark):
     assert_new_db_functions_backwards_compatible()


-def test_aggregation_cli_no_credentials_file(monkeypatch, spark):
-    class Dataset:
-        @staticmethod
-        def from_source(*args, **kwargs):
-            return Dataset()
-
-        def where(self, *args, **kwargs):
-            return self
-
-        def records(self, *args, **kwargs):
-            return spark.sparkContext.parallelize(generate_pings())
-
-    monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset)
-
+def test_aggregation_cli_no_credentials_file(mock_dataset):
     result = CliRunner().invoke(
         run_aggregator,
         [
@@ -219,6 +206,61 @@ def test_aggregation_cli_no_credentials_file(monkeypatch, spark):
     assert_new_db_functions_backwards_compatible()


+def test_aggregation_cli_credentials_option(mock_dataset):
+    empty_env = {
+        "DB_TEST_URL": "",
+        "POSTGRES_DB": "",
+        "POSTGRES_USER": "",
+        "POSTGRES_PASS": "",
+        "POSTGRES_HOST": "",
+        "POSTGRES_RO_HOST": "",
+    }
+    options = [
+        "--postgres-db",
+        "postgres",
+        "--postgres-user",
+        "postgres",
+        "--postgres-pass",
+        "pass",
+        "--postgres-host",
+        "db",
+        "--postgres-ro-host",
+        "db"
+    ]
+    result = CliRunner().invoke(
+        run_aggregator,
+        [
+            "--date",
+            SUBMISSION_DATE_1.strftime('%Y%m%d'),
+            "--channels",
+            "nightly,beta",
+            "--num-partitions",
+            10,
+        ] + options,
+        env=empty_env,
+        catch_exceptions=False,
+    )
+
+    assert result.exit_code == 0, result.output
+    assert_new_db_functions_backwards_compatible()
+
+    # now test that missing an option will exit with non-zero
+    result = CliRunner().invoke(
+        run_aggregator,
+        [
+            "--date",
+            SUBMISSION_DATE_1.strftime('%Y%m%d'),
+            "--channels",
+            "nightly,beta",
+            "--num-partitions",
+            10,
+        ] + options[:-2], # missing ro_host
+        env=empty_env,
+        catch_exceptions=False,
+    )
+    assert result.exit_code == 1
+
+
 @runif_bigquery_testing_enabled
 def test_aggregation_cli_bigquery(tmp_path, bq_testing_table):
     test_creds = str(tmp_path / "creds")
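
One detail worth noting about the new credentials-option test: CliRunner.invoke(env=...) temporarily overlays the given mapping onto os.environ for the duration of the call, so the empty strings in empty_env mask any credentials inherited from the developer's shell and force the option-driven and failure paths to be exercised. A small demonstration of that behavior; the command here is invented for illustration:

import os

import click
from click.testing import CliRunner

@click.command()
def show():
    click.echo(os.environ.get("POSTGRES_DB", "<unset>"))

runner = CliRunner()
print(runner.invoke(show, env={"POSTGRES_DB": ""}).output)  # the override wins: prints an empty line
print(runner.invoke(show).output)                           # falls back to the real environment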