Add options to the cli for passing in arguments to postgres db
This commit is contained in:
Родитель
d084522ec6
Коммит
30bfe79f5f
|
@ -23,6 +23,11 @@ def entry_point():
|
||||||
)
|
)
|
||||||
@click.option("--credentials-bucket", type=str, required=False)
|
@click.option("--credentials-bucket", type=str, required=False)
|
||||||
@click.option("--credentials-prefix", type=str, required=False)
|
@click.option("--credentials-prefix", type=str, required=False)
|
||||||
|
@click.option("--postgres-db", type=str, required=False)
|
||||||
|
@click.option("--postgres-user", type=str, required=False)
|
||||||
|
@click.option("--postgres-pass", type=str, required=False)
|
||||||
|
@click.option("--postgres-host", type=str, required=False)
|
||||||
|
@click.option("--postgres-ro-host", type=str, required=False)
|
||||||
@click.option("--num-partitions", type=int, default=10000)
|
@click.option("--num-partitions", type=int, default=10000)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--source",
|
"--source",
|
||||||
|
@ -40,6 +45,11 @@ def run_aggregator(
|
||||||
credentials_protocol,
|
credentials_protocol,
|
||||||
credentials_bucket,
|
credentials_bucket,
|
||||||
credentials_prefix,
|
credentials_prefix,
|
||||||
|
postgres_db,
|
||||||
|
postgres_user,
|
||||||
|
postgres_pass,
|
||||||
|
postgres_host,
|
||||||
|
postgres_ro_host,
|
||||||
num_partitions,
|
num_partitions,
|
||||||
source,
|
source,
|
||||||
project_id,
|
project_id,
|
||||||
|
@ -55,12 +65,22 @@ def run_aggregator(
|
||||||
mapping = {"file": "file", "s3": "s3a", "gcs": "gs"}
|
mapping = {"file": "file", "s3": "s3a", "gcs": "gs"}
|
||||||
return f"{mapping[protocol]}://{bucket}/{prefix}"
|
return f"{mapping[protocol]}://{bucket}/{prefix}"
|
||||||
|
|
||||||
if credentials_bucket and credentials_prefix:
|
# priority of reading credentials is options > credentials file > environment
|
||||||
|
option_credentials = {
|
||||||
|
"POSTGRES_DB": postgres_db,
|
||||||
|
"POSTGRES_USER": postgres_user,
|
||||||
|
"POSTGRES_PASS": postgres_pass,
|
||||||
|
"POSTGRES_HOST": postgres_host,
|
||||||
|
"POSTGRES_RO_HOST": postgres_ro_host,
|
||||||
|
}
|
||||||
|
if all(option_credentials.values()):
|
||||||
|
print("reading credentials from options")
|
||||||
|
environ.update(option_credentials)
|
||||||
|
elif credentials_bucket and credentials_prefix:
|
||||||
path = create_path(credentials_protocol, credentials_bucket, credentials_prefix)
|
path = create_path(credentials_protocol, credentials_bucket, credentials_prefix)
|
||||||
print(f"reading credentials from {path}")
|
print(f"reading credentials from {path}")
|
||||||
creds = spark.read.json(path, multiLine=True).first().asDict()
|
creds = spark.read.json(path, multiLine=True).first().asDict()
|
||||||
for k, v in creds.items():
|
environ.update(creds)
|
||||||
environ[k] = v
|
|
||||||
else:
|
else:
|
||||||
print(f"assuming credentials from the environment")
|
print(f"assuming credentials from the environment")
|
||||||
|
|
||||||
|
|
|
@ -134,16 +134,8 @@ def test_notice_logging_cursor():
|
||||||
lc.check(expected)
|
lc.check(expected)
|
||||||
|
|
||||||
|
|
||||||
def test_aggregation_cli(tmp_path, monkeypatch, spark):
|
@pytest.fixture
|
||||||
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "access")
|
def mock_dataset(monkeypatch, spark):
|
||||||
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret")
|
|
||||||
|
|
||||||
test_creds = str(tmp_path / "creds")
|
|
||||||
# generally points to the production credentials
|
|
||||||
creds = {"DB_TEST_URL": "dbname=postgres user=postgres host=db"}
|
|
||||||
with open(test_creds, "w") as f:
|
|
||||||
json.dump(creds, f)
|
|
||||||
|
|
||||||
class Dataset:
|
class Dataset:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_source(*args, **kwargs):
|
def from_source(*args, **kwargs):
|
||||||
|
@ -157,6 +149,14 @@ def test_aggregation_cli(tmp_path, monkeypatch, spark):
|
||||||
|
|
||||||
monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset)
|
monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregation_cli(tmp_path, mock_dataset):
|
||||||
|
test_creds = str(tmp_path / "creds")
|
||||||
|
# generally points to the production credentials
|
||||||
|
creds = {"DB_TEST_URL": "dbname=postgres user=postgres host=db"}
|
||||||
|
with open(test_creds, "w") as f:
|
||||||
|
json.dump(creds, f)
|
||||||
|
|
||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
run_aggregator,
|
run_aggregator,
|
||||||
[
|
[
|
||||||
|
@ -180,20 +180,7 @@ def test_aggregation_cli(tmp_path, monkeypatch, spark):
|
||||||
assert_new_db_functions_backwards_compatible()
|
assert_new_db_functions_backwards_compatible()
|
||||||
|
|
||||||
|
|
||||||
def test_aggregation_cli_no_credentials_file(monkeypatch, spark):
|
def test_aggregation_cli_no_credentials_file(mock_dataset):
|
||||||
class Dataset:
|
|
||||||
@staticmethod
|
|
||||||
def from_source(*args, **kwargs):
|
|
||||||
return Dataset()
|
|
||||||
|
|
||||||
def where(self, *args, **kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def records(self, *args, **kwargs):
|
|
||||||
return spark.sparkContext.parallelize(generate_pings())
|
|
||||||
|
|
||||||
monkeypatch.setattr("mozaggregator.aggregator.Dataset", Dataset)
|
|
||||||
|
|
||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
run_aggregator,
|
run_aggregator,
|
||||||
[
|
[
|
||||||
|
@ -219,6 +206,61 @@ def test_aggregation_cli_no_credentials_file(monkeypatch, spark):
|
||||||
assert_new_db_functions_backwards_compatible()
|
assert_new_db_functions_backwards_compatible()
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregation_cli_credentials_option(mock_dataset):
|
||||||
|
empty_env = {
|
||||||
|
"DB_TEST_URL": "",
|
||||||
|
"POSTGRES_DB": "",
|
||||||
|
"POSTGRES_USER": "",
|
||||||
|
"POSTGRES_PASS": "",
|
||||||
|
"POSTGRES_HOST": "",
|
||||||
|
"POSTGRES_RO_HOST": ","
|
||||||
|
}
|
||||||
|
options = [
|
||||||
|
"--postgres-db",
|
||||||
|
"postgres",
|
||||||
|
"--postgres-user",
|
||||||
|
"postgres",
|
||||||
|
"--postgres-pass",
|
||||||
|
"pass",
|
||||||
|
"--postgres-host",
|
||||||
|
"db",
|
||||||
|
"--postgres-ro-host",
|
||||||
|
"db"
|
||||||
|
]
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
run_aggregator,
|
||||||
|
[
|
||||||
|
"--date",
|
||||||
|
SUBMISSION_DATE_1.strftime('%Y%m%d'),
|
||||||
|
"--channels",
|
||||||
|
"nightly,beta",
|
||||||
|
"--num-partitions",
|
||||||
|
10,
|
||||||
|
] + options,
|
||||||
|
env=empty_env,
|
||||||
|
catch_exceptions=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 0, result.output
|
||||||
|
assert_new_db_functions_backwards_compatible()
|
||||||
|
|
||||||
|
# now test that missing an option will exit with non-zero
|
||||||
|
result = CliRunner().invoke(
|
||||||
|
run_aggregator,
|
||||||
|
[
|
||||||
|
"--date",
|
||||||
|
SUBMISSION_DATE_1.strftime('%Y%m%d'),
|
||||||
|
"--channels",
|
||||||
|
"nightly,beta",
|
||||||
|
"--num-partitions",
|
||||||
|
10,
|
||||||
|
] + options[:2], # missing ro_host
|
||||||
|
env=empty_env,
|
||||||
|
catch_exceptions=False,
|
||||||
|
)
|
||||||
|
assert result.exit_code == 1
|
||||||
|
|
||||||
|
|
||||||
@runif_bigquery_testing_enabled
|
@runif_bigquery_testing_enabled
|
||||||
def test_aggregation_cli_bigquery(tmp_path, bq_testing_table):
|
def test_aggregation_cli_bigquery(tmp_path, bq_testing_table):
|
||||||
test_creds = str(tmp_path / "creds")
|
test_creds = str(tmp_path / "creds")
|
||||||
|
|
Загрузка…
Ссылка в новой задаче