Added dataproc and gcloud integration
* Added dataproc and bq connector hooks
* Added autoscale policy
* Added a cfretl.main runner class to instrument dataproc and RemoteSettings
* Added a RemoteSettings clone tool for the 'cfr' collection into the 'cfr-ml-control' and 'cfr-ml-experiments' collections
This commit is contained in:
Parent: d9f9084739
Commit: d8fd9ef374
|
@ -1,10 +1,11 @@
|
|||
# Community Participation Guidelines
|
||||
|
||||
This repository is governed by Mozilla's code of conduct and etiquette guidelines.
|
||||
This repository is governed by Mozilla's code of conduct and etiquette guidelines.
|
||||
For more details, please read the
|
||||
[Mozilla Community Participation Guidelines](https://www.mozilla.org/about/governance/policies/participation/).
|
||||
[Mozilla Community Participation Guidelines](https://www.mozilla.org/about/governance/policies/participation/).
|
||||
|
||||
## How to Report
|
||||
|
||||
For more information on how to report violations of the Community Participation Guidelines, please read our '[How to Report](https://www.mozilla.org/about/governance/policies/participation/reporting/)' page.
|
||||
|
||||
<!--
|
||||
|
|
|
@ -25,3 +25,5 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|||
COPY . /app
|
||||
RUN python setup.py install
|
||||
USER app
|
||||
|
||||
ENTRYPOINT ["/bin/sh"]
|
||||
|
|
24
Makefile
|
@ -1,6 +1,9 @@
|
|||
.PHONY: build up tests flake8 ci tests-with-cov
|
||||
|
||||
all: auth build run
|
||||
include envfile
|
||||
export $(shell sed 's/=.*//' envfile)
|
||||
|
||||
all: auth import_policy upload build run
|
||||
|
||||
auth:
|
||||
gcloud auth application-default login
|
||||
|
@ -13,6 +16,17 @@ pytest:
|
|||
build:
|
||||
docker build -t cfr-numbermuncher:latest .
|
||||
|
||||
gcloud_tagupload:
|
||||
docker tag cfr-numbermuncher:latest gcr.io/cfr-personalization-experiment/cfr-numbermuncher:latest
|
||||
docker push gcr.io/cfr-personalization-experiment/cfr-numbermuncher:latest
|
||||
|
||||
|
||||
import_policy:
|
||||
gcloud dataproc autoscaling-policies import cfr-personalization-autoscale --region=$(GCLOUD_REGION) --source=./dataproc/autoscale_policy.yaml --verbosity info
|
||||
|
||||
upload:
|
||||
gsutil cp scripts/compute_weights.py gs://cfr-ml-jobs/compute_weights.py
|
||||
|
||||
run:
|
||||
# Create the bot user (not required in prod)
|
||||
docker run -it cfr-numbermuncher:latest bin/install_bot.sh
|
||||
|
@ -21,3 +35,11 @@ run:
|
|||
docker run -v ~/.config:/app/.config \
|
||||
-e GOOGLE_CLOUD_PROJECT=moz-fx-data-derived-datasets \
|
||||
-it cfr-numbermuncher:latest python -m cfretl.main
|
||||
|
||||
|
||||
cluster:
|
||||
gcloud dataproc clusters create cfr-sklearn-cluster3 \
|
||||
--zone=$(GCLOUD_ZONE) \
|
||||
--image-version=preview \
|
||||
--initialization-actions gs://cfr-ml-jobs/actions/python/dataproc_custom.sh
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ login to GCP using your own credentials.
|
|||
|
||||
From the command line - issue: `gcloud auth application-default login`
|
||||
|
||||
|
||||
To build and startup the container - the simplest thing to do is to
|
||||
run
|
||||
|
||||
|
@ -19,29 +18,26 @@ make all
|
|||
|
||||
## Building the container
|
||||
|
||||
A docker file is provided to build the container. You can issue
|
||||
A docker file is provided to build the container. You can issue
|
||||
`make build` to create a local image.
|
||||
|
||||
|
||||
# Kinto authentication
|
||||
|
||||
The container is setup to use a default user with a username/password
|
||||
pair of : (devuser, devpass) against the kinto dev server.
|
||||
|
||||
|
||||
## Building the container
|
||||
|
||||
A standard Dockerfile is provided to build the container - the
|
||||
simplest thing to build the container is to issue: `make build`
|
||||
|
||||
|
||||
## Install the devuser and set up cfr-control, cfr-experiments and cfr-models
|
||||
|
||||
Use `make run` to spin up a testing container.
|
||||
|
||||
This will install initial data into the dev instance of Remote
|
||||
Settings at https://kinto.dev.mozaws.net/v1 and start writing out
|
||||
weight updates. Updates are currently set as a constant of 1 second
|
||||
weight updates. Updates are currently set as a constant of 1 second
|
||||
updates to ease testing.
|
||||
|
||||
The `run` target will automatically mount your GCloud authentication
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
#!/bin/sh
|
||||
SERVER=https://kinto.dev.mozaws.net/v1
|
||||
|
||||
# Delete the bot first to start clean
|
||||
curl -X DELETE ${SERVER}/accounts/devuser -u devuser:devpass
|
||||
|
||||
curl -v -X PUT ${SERVER}/accounts/devuser \
|
||||
-d '{"data": {"password": "devpass"}}' \
|
||||
-H 'Content-Type:application/json'
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
"""
|
||||
This module loads data directly from the telemetry pings table.
|
||||
"""
|
||||
|
||||
from google.cloud import bigquery
|
||||
from datetime import timedelta
|
||||
import datetime
|
||||
import pytz
|
||||
import logging
|
||||
|
||||
LIMIT_CLAUSE = ""
|
||||
BQ_DATE_FORMAT = "%Y-%m-%d %H:00:00"
|
||||
|
||||
|
||||
class InvalidVector(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ASLoader:
|
||||
def __init__(self):
|
||||
# The BigQuery client is lazily bound so that tests run fast
|
||||
self._bqclient = None
|
||||
|
||||
@property
|
||||
def _client(self):
|
||||
if self._bqclient is None:
|
||||
self._bqclient = bigquery.Client()
|
||||
return self._bqclient
|
||||
|
||||
def _build_query(self, dt, limit_rowcount=None):
|
||||
start_dt = datetime.datetime(
|
||||
dt.year, dt.month, dt.day, dt.hour, 0, 0, tzinfo=pytz.utc
|
||||
)
|
||||
end_dt = start_dt + timedelta(hours=1)
|
||||
|
||||
start_ts = start_dt.strftime(BQ_DATE_FORMAT)
|
||||
end_ts = end_dt.strftime(BQ_DATE_FORMAT)
|
||||
|
||||
# production query will be different
|
||||
query_tmpl = (
|
||||
"select * from `moz-fx-data-shar-nonprod-efed`.messaging_system_live.cfr_v1 "
|
||||
"where submission_timestamp > '{start_ts:s}' and submission_timestamp <= '{end_ts:s}' "
|
||||
"limit 10"
|
||||
)
|
||||
|
||||
query = query_tmpl.format(start_ts=start_ts, end_ts=end_ts)
|
||||
return query.strip()
|
||||
|
||||
def _get_pings(self, dt=None, limit_rowcount=None, as_dict=False):
|
||||
if dt is None:
|
||||
logging.warn("No date was specified - defaulting to 7 days ago for an hour")
|
||||
dt = datetime.datetime.now() - timedelta(days=7)
|
||||
|
||||
query = self._build_query(dt, limit_rowcount=limit_rowcount)
|
||||
query_job = self._client.query(
|
||||
query,
|
||||
# Location must match that of the dataset(s) referenced in the query.
|
||||
location="US",
|
||||
) # API request - starts the query
|
||||
|
||||
logging.info("Running : {:s}".format(query))
|
||||
for i in query_job:
|
||||
if as_dict:
|
||||
yield dict([(k, i[k]) for k in i.keys()])
|
||||
else:
|
||||
yield i
|
||||
|
||||
def compute_vector_weights(self):
|
||||
assert len([row for row in self._get_pings()]) > 0
|
||||
# TODO: Call out to dataproc to compute the model here
|
||||
# raise NotImplementedError()
|
|
@ -0,0 +1,270 @@
|
|||
import pkg_resources
|
||||
import time
|
||||
|
||||
from google.cloud import bigquery
|
||||
from google.cloud import dataproc_v1
|
||||
from google.cloud import storage
|
||||
from google.cloud.dataproc_v1.gapic.transports import cluster_controller_grpc_transport
|
||||
from google.cloud.dataproc_v1.gapic.transports import job_controller_grpc_transport
|
||||
|
||||
|
||||
waiting_callback = False
|
||||
|
||||
|
||||
def cluster_callback(operation_future):
|
||||
# Reset global when callback returns.
|
||||
global waiting_callback
|
||||
waiting_callback = False
|
||||
|
||||
|
||||
def wait_for_cluster_create():
|
||||
"""Wait for cluster creation."""
|
||||
print("Waiting for cluster creation...")
|
||||
|
||||
while True:
|
||||
if not waiting_callback:
|
||||
print("Cluster created.")
|
||||
break
|
||||
|
||||
|
||||
def wait_for_cluster_delete():
|
||||
"""Wait for cluster creation."""
|
||||
print("Waiting for cluster deletion...")
|
||||
|
||||
while True:
|
||||
if not waiting_callback:
|
||||
print("Cluster deleted.")
|
||||
break
|
||||
|
||||
|
||||
class DataprocFacade:
|
||||
"""
|
||||
This class exposes a minimal interface to execute PySpark jobs on
|
||||
Dataproc.
|
||||
|
||||
Basic features include:
|
||||
|
||||
* Creating a custom dataproc cluster if one does not exist
|
||||
* Execution of the script and emitting output to a BigQuery table
|
||||
"""
|
||||
|
||||
def __init__(self, project_id, cluster_name, zone):
|
||||
self._project_id = project_id
|
||||
self._cluster_name = cluster_name
|
||||
self._zone = zone
|
||||
self._region = self._get_region_from_zone(zone)
|
||||
|
||||
self._dataproc_cluster_client = None
|
||||
self._dataproc_job_client = None
|
||||
|
||||
def dataproc_job_client(self):
|
||||
"""
|
||||
Lazily obtain a GCP Dataproc JobController client
|
||||
"""
|
||||
if self._dataproc_job_client is None:
|
||||
job_transport = job_controller_grpc_transport.JobControllerGrpcTransport(
|
||||
address="{}-dataproc.googleapis.com:443".format(self._region)
|
||||
)
|
||||
self._dataproc_job_client = dataproc_v1.JobControllerClient(job_transport)
|
||||
return self._dataproc_job_client
|
||||
|
||||
def dataproc_cluster_client(self):
|
||||
"""
|
||||
Lazily create a Dataproc ClusterController client to setup or
|
||||
tear down dataproc clusters
|
||||
"""
|
||||
|
||||
if self._dataproc_cluster_client is None:
|
||||
client_transport = cluster_controller_grpc_transport.ClusterControllerGrpcTransport(
|
||||
address="{}-dataproc.googleapis.com:443".format(self._region)
|
||||
)
|
||||
self._dataproc_cluster_client = dataproc_v1.ClusterControllerClient(
|
||||
client_transport
|
||||
)
|
||||
return self._dataproc_cluster_client
|
||||
|
||||
def delete_cluster_if_exists(self):
|
||||
"""
|
||||
Destroy the Dataproc cluster if it exists
|
||||
"""
|
||||
try:
|
||||
if self.cluster_exists():
|
||||
self._delete_cluster()
|
||||
wait_for_cluster_delete()
|
||||
else:
|
||||
print("Cluster {} already exists.".format(self._cluster_name))
|
||||
except Exception as exc:
|
||||
raise exc
|
||||
|
||||
def create_cluster_if_not_exists(self):
|
||||
"""
|
||||
Create Dataproc cluster if one doesn't exist yet
|
||||
"""
|
||||
try:
|
||||
if not self.cluster_exists():
|
||||
self._create_cluster()
|
||||
wait_for_cluster_create()
|
||||
else:
|
||||
print("Cluster {} already exists.".format(self._cluster_name))
|
||||
except Exception as exc:
|
||||
raise exc
|
||||
|
||||
def cluster_exists(self):
|
||||
"""
|
||||
Check that the Dataproc Cluster exists
|
||||
"""
|
||||
try:
|
||||
return self.dataproc_cluster_client().get_cluster(
|
||||
self._project_id, self._region, self._cluster_name
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def check_jobexists(self, bucket_name, filename):
|
||||
"""
|
||||
Check that a pyspark script exists in Google Cloud Storage
|
||||
"""
|
||||
storage_client = storage.Client()
|
||||
bucket = storage_client.get_bucket(bucket_name)
|
||||
blob = bucket.blob(filename)
|
||||
return blob.exists()
|
||||
|
||||
def upload_sparkjob(self, bucket_name, src_name):
|
||||
"""
|
||||
Uploads a PySpark file to a GCS bucket for later execution
|
||||
"""
|
||||
storage_client = storage.Client()
|
||||
bucket = storage_client.get_bucket(bucket_name)
|
||||
blob = bucket.blob(src_name)
|
||||
blob.upload_from_filename(
|
||||
pkg_resources.resource_filename("cfretl", "scripts/{}".format(src_name))
|
||||
)
|
||||
print("Uploaded {} to {}".format(src_name, src_name))
|
||||
|
||||
def run_job(self, bucket_name, spark_filename):
|
||||
"""
|
||||
Execute a PySpark job from a GCS bucket in our Dataproc cluster
|
||||
"""
|
||||
job_id = self._submit_pyspark_job(bucket_name, spark_filename)
|
||||
self._wait_for_job(job_id)
|
||||
|
||||
def copy_bq_table(self, dataset_name, src_tbl, dst_tbl):
|
||||
"""
|
||||
Copy a BigQuery table in this project
|
||||
"""
|
||||
client = bigquery.Client()
|
||||
dataset = client.dataset(dataset_name)
|
||||
src = dataset.table(src_tbl)
|
||||
dst = dataset.table(dst_tbl)
|
||||
src_name = ".".join([src._project, src._dataset_id, src.table_id])
|
||||
dst_name = ".".join([dst._project, dst._dataset_id, dst.table_id])
|
||||
copy_job = client.copy_table(src_name, dst_name)
|
||||
copy_job.result()  # block until the copy job completes
|
||||
|
||||
# Internal Methods
|
||||
#
|
||||
|
||||
def _delete_cluster(self):
|
||||
cluster = self.dataproc_cluster_client().delete_cluster(
|
||||
self._project_id, self._region, self._cluster_name
|
||||
)
|
||||
cluster.add_done_callback(cluster_callback)
|
||||
global waiting_callback
|
||||
waiting_callback = True
|
||||
|
||||
def _get_region_from_zone(self, zone):
|
||||
try:
|
||||
region_as_list = zone.split("-")[:-1]
|
||||
return "-".join(region_as_list)
|
||||
except (AttributeError, IndexError, ValueError):
|
||||
raise ValueError("Invalid zone provided, please check your input.")
|
||||
|
||||
def _create_cluster(self):
|
||||
"""Create the cluster."""
|
||||
|
||||
# TODO: pass in the bucket somehow. maybe an attribute
|
||||
# settings
|
||||
# bucket_name = "cfr-ml-jobs"
|
||||
|
||||
print("Creating cluster...")
|
||||
cluster_config = {
|
||||
"cluster_name": self._cluster_name,
|
||||
"project_id": self._project_id,
|
||||
"config": {
|
||||
"master_config": {
|
||||
"num_instances": 1,
|
||||
"machine_type_uri": "n1-standard-1",
|
||||
"disk_config": {
|
||||
"boot_disk_type": "pd-ssd",
|
||||
"num_local_ssds": 1,
|
||||
"boot_disk_size_gb": 1000,
|
||||
},
|
||||
},
|
||||
"worker_config": {
|
||||
"num_instances": 2,
|
||||
"machine_type_uri": "n1-standard-8",
|
||||
"disk_config": {
|
||||
"boot_disk_type": "pd-ssd",
|
||||
"num_local_ssds": 1,
|
||||
"boot_disk_size_gb": 1000,
|
||||
},
|
||||
},
|
||||
"autoscaling_config": {
|
||||
"policy_uri": "projects/cfr-personalization-experiment/regions/{region:s}/autoscalingPolicies/cfr-personalization-autoscale".format(
|
||||
region=self._region
|
||||
)
|
||||
},
|
||||
"initialization_actions": [
|
||||
{
|
||||
"executable_file": "gs://cfr-ml-jobs/actions/python/dataproc_custom.sh"
|
||||
}
|
||||
],
|
||||
"software_config": {"image_version": "1.4.16-ubuntu18"},
|
||||
},
|
||||
}
|
||||
|
||||
cluster = self.dataproc_cluster_client().create_cluster(
|
||||
self._project_id, self._region, cluster_config
|
||||
)
|
||||
cluster.add_done_callback(cluster_callback)
|
||||
global waiting_callback
|
||||
waiting_callback = True
|
||||
|
||||
def _submit_pyspark_job(self, bucket_name, filename):
|
||||
"""Submit the Pyspark job to the cluster (assumes `filename` was uploaded
|
||||
to `bucket_name`).
|
||||
|
||||
Note that the bigquery connector is added at job submission
|
||||
time and not cluster creation time.
|
||||
"""
|
||||
job_details = {
|
||||
"placement": {"cluster_name": self._cluster_name},
|
||||
"pyspark_job": {
|
||||
"main_python_file_uri": "gs://{}/{}".format(bucket_name, filename),
|
||||
"jar_file_uris": ["gs://spark-lib/bigquery/spark-bigquery-latest.jar"],
|
||||
},
|
||||
}
|
||||
|
||||
result = self.dataproc_job_client().submit_job(
|
||||
project_id=self._project_id, region=self._region, job=job_details
|
||||
)
|
||||
job_id = result.reference.job_id
|
||||
print("Submitted job ID {}.".format(job_id))
|
||||
return job_id
|
||||
|
||||
def _wait_for_job(self, job_id):
|
||||
"""Wait for job to complete or error out."""
|
||||
print("Waiting for job to finish...")
|
||||
while True:
|
||||
job = self.dataproc_job_client().get_job(
|
||||
self._project_id, self._region, job_id
|
||||
)
|
||||
# Handle exceptions
|
||||
if job.status.State.Name(job.status.state) == "ERROR":
|
||||
raise Exception(job.status.details)
|
||||
elif job.status.State.Name(job.status.state) == "DONE":
|
||||
print("Job finished.")
|
||||
return job
|
||||
|
||||
# Need to sleep a little or else we'll eat all the CPU
|
||||
time.sleep(0.1)
|
|
@ -1,73 +1,46 @@
|
|||
import asyncio
|
||||
from cfretl.asloader import ASLoader
|
||||
"""
|
||||
This script will create a cluster if required, and start the dataproc
|
||||
job to write out to a table.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pkg_resources
|
||||
|
||||
import click
|
||||
from cfretl.remote_settings import CFRRemoteSettings
|
||||
import random
|
||||
import requests
|
||||
|
||||
DELAY = 1
|
||||
from cfretl.dataproc import DataprocFacade
|
||||
|
||||
|
||||
def get_mock_vector():
|
||||
CFR_IDS = [
|
||||
"BOOKMARK_SYNC_CFR",
|
||||
"CRYPTOMINERS_PROTECTION",
|
||||
"CRYPTOMINERS_PROTECTION_71",
|
||||
"FACEBOOK_CONTAINER_3",
|
||||
"FACEBOOK_CONTAINER_3_72",
|
||||
"FINGERPRINTERS_PROTECTION",
|
||||
"FINGERPRINTERS_PROTECTION_71",
|
||||
"GOOGLE_TRANSLATE_3",
|
||||
"GOOGLE_TRANSLATE_3_72",
|
||||
"MILESTONE_MESSAGE",
|
||||
"PDF_URL_FFX_SEND",
|
||||
"PIN_TAB",
|
||||
"PIN_TAB_72",
|
||||
"SAVE_LOGIN",
|
||||
"SAVE_LOGIN_72",
|
||||
"SEND_RECIPE_TAB_CFR",
|
||||
"SEND_TAB_CFR",
|
||||
"SOCIAL_TRACKING_PROTECTION",
|
||||
"SOCIAL_TRACKING_PROTECTION_71",
|
||||
"WNP_MOMENTS_1",
|
||||
"WNP_MOMENTS_2",
|
||||
"WNP_MOMENTS_SYNC",
|
||||
"YOUTUBE_ENHANCE_3",
|
||||
"YOUTUBE_ENHANCE_3_72",
|
||||
]
|
||||
return dict(zip(CFR_IDS, [random.randint(0, 16000) for i in range(len(CFR_IDS))]))
|
||||
def load_mock_model():
|
||||
fname = pkg_resources.resource_filename("cfretl", "scripts/cfr_ml_model.json")
|
||||
mock_model = json.load(open(fname))["data"][0]
|
||||
return mock_model
|
||||
|
||||
|
||||
def bootstrap_test(cfr_rs):
|
||||
print("Installed CFR Control: {}".format(cfr_rs.clone_to_cfr_control(cfr_data())))
|
||||
print(
|
||||
"Installed CFR Experimetns: {}".format(
|
||||
cfr_rs.clone_to_cfr_experiment(cfr_data())
|
||||
)
|
||||
)
|
||||
@click.command()
|
||||
@click.option("--project-id", default="cfr-personalization-experiment")
|
||||
@click.option("--cluster-name", default="cfr-experiments")
|
||||
@click.option("--zone", default="us-west1-a")
|
||||
@click.option("--bucket-name", default="cfr-ml-jobs")
|
||||
@click.option("--spark-filename", default="compute_weights.py")
|
||||
def main(
|
||||
project_id=None, cluster_name=None, zone=None, bucket_name=None, spark_filename=None
|
||||
):
|
||||
dataproc = DataprocFacade(project_id, cluster_name, zone)
|
||||
dataproc.create_cluster_if_not_exists()
|
||||
|
||||
# Upload the script from the cfretl.scripts directory
|
||||
dataproc.upload_sparkjob(bucket_name, spark_filename)
|
||||
dataproc.run_job(bucket_name, spark_filename)
|
||||
|
||||
def cfr_data():
|
||||
return requests.get(
|
||||
"https://firefox.settings.services.mozilla.com/v1/buckets/main/collections/cfr/records"
|
||||
).json()["data"]
|
||||
# TODO: do something to transform the bq result table
|
||||
# into a final model
|
||||
|
||||
remote_settings = CFRRemoteSettings()
|
||||
remote_settings.write_models(load_mock_model())
|
||||
|
||||
async def compute_models():
|
||||
# _ = asyncio.get_running_loop()
|
||||
|
||||
asloader = ASLoader()
|
||||
cfr_rs = CFRRemoteSettings()
|
||||
|
||||
# This sets up the test environment
|
||||
bootstrap_test(cfr_rs)
|
||||
|
||||
while True:
|
||||
_ = asloader.compute_vector_weights() # noqa
|
||||
write_status = cfr_rs.write_models(get_mock_vector())
|
||||
print("Write status: {}".format(write_status))
|
||||
# Wait to run the next batch
|
||||
await asyncio.sleep(DELAY)
|
||||
dataproc.destroy_cluster()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(compute_models())
|
||||
main()
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class CFRModel:
|
||||
def _one_cfr_model_feature(self, cfr_id, feature_id, p0, p1):
|
||||
"""
|
||||
Generate the JSON data structure for a single CFR feature (accept, reject) pair
|
||||
"""
|
||||
snippet = """{{"{cfr_id:s}": {{
|
||||
"{feature_id:s}": {{
|
||||
"p_given_cfr_acceptance": {p0:0.09f},
|
||||
"p_given_cfr_rejection": {p1:0.09f}
|
||||
}}}}}}"""
|
||||
|
||||
txt = snippet.format(cfr_id=cfr_id, feature_id=feature_id, p0=p0, p1=p1)
|
||||
jdata = json.loads(txt)
|
||||
return jdata
|
||||
|
||||
def generate_cfr_model(self, cfr_model_cfg, version):
|
||||
"""
|
||||
Generate the complete cfr-ml-models data
|
||||
"""
|
||||
model_cfrid = {}
|
||||
for cfr_id, cfr_cfg in cfr_model_cfg.items():
|
||||
if cfr_id not in model_cfrid:
|
||||
model_cfrid[cfr_id] = {}
|
||||
|
||||
prior_p0 = cfr_cfg["p0"]
|
||||
prior_p1 = cfr_cfg["p1"]
|
||||
|
||||
for feature in cfr_cfg["features"]:
|
||||
feature_dict = self._one_cfr_model_feature(
|
||||
cfr_id, feature["feature_id"], feature["p0"], feature["p1"]
|
||||
)
|
||||
|
||||
model_cfrid[cfr_id].update(feature_dict[cfr_id])
|
||||
model_cfrid[cfr_id].update(
|
||||
json.loads(
|
||||
"""{{"prior_cfr": {{"p_acceptance": {:0.09f}, "p_rejection": {:0.09f}}}}}""".format(
|
||||
prior_p0, prior_p1
|
||||
)
|
||||
)
|
||||
)
|
||||
model = {"models_by_cfr_id": model_cfrid, "version": version}
|
||||
return model
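
As a rough illustration of the shape this produces, a single-CFR, single-feature config (hypothetical probability values) would yield the following; the nested keys mirror the snippet template above:

```python
from cfretl.models import CFRModel

# Hypothetical config; the keys are what generate_cfr_model expects.
cfg = {
    "CFR_ID_1": {
        "p0": 0.45,
        "p1": 0.55,
        "features": [{"feature_id": "feature_1", "p0": 0.7, "p1": 0.5}],
    }
}
model = CFRModel().generate_cfr_model(cfg, version=20191201000000)
# model == {
#     "models_by_cfr_id": {
#         "CFR_ID_1": {
#             "feature_1": {"p_given_cfr_acceptance": 0.7, "p_given_cfr_rejection": 0.5},
#             "prior_cfr": {"p_acceptance": 0.45, "p_rejection": 0.55},
#         }
#     },
#     "version": 20191201000000,
# }
```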
|
|
@ -10,13 +10,31 @@
|
|||
from cfretl import settings
|
||||
|
||||
import json
|
||||
import jsonschema
|
||||
|
||||
# import jsonschema
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
|
||||
CFR_MODELS = "cfr-models"
|
||||
CFR_EXPERIMENT = "cfr-experiment"
|
||||
CFR_CONTROL = "cfr-control"
|
||||
CFR_MODELS = "cfr-ml-model"
|
||||
CFR_EXPERIMENT = "cfr-ml-experiments"
|
||||
CFR_CONTROL = "cfr-ml-control"
|
||||
|
||||
FEATURES_LIST = [
|
||||
"have_firefox_as_default_browser",
|
||||
"active_ticks",
|
||||
"total_uri_count",
|
||||
"about_preferences_non_default_value_count",
|
||||
"has_at_least_one_self_installed_addon",
|
||||
"has_at_least_one_self_installed_password_manager",
|
||||
"has_at_least_one_self_installed_theme",
|
||||
"dark_mode_active",
|
||||
"total_bookmarks_count",
|
||||
"has_at_least_two_logins_saved_in_the_browser",
|
||||
"firefox_accounts_configured",
|
||||
"locale",
|
||||
"profile_age",
|
||||
"main_monitor_screen_width",
|
||||
]
|
||||
|
||||
|
||||
class SecurityError(Exception):
|
||||
|
@ -40,6 +58,7 @@ class CFRRemoteSettings:
|
|||
'cfr-control' and 'cfr-models'.
|
||||
|
||||
See "CFR Machine Learning Experiment" doc for full details.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
@ -119,6 +138,23 @@ class CFRRemoteSettings:
|
|||
raise
|
||||
return jdata
|
||||
|
||||
def debug_load_cfr(self):
|
||||
"""
|
||||
Read the production model 'cfr' collection.
|
||||
|
||||
This is only used for testing and initial setup of the
|
||||
collections.
|
||||
"""
|
||||
try:
|
||||
url = "https://firefox.settings.services.mozilla.com/v1/buckets/main/collections/cfr/records"
|
||||
resp = requests.get(url)
|
||||
jdata = resp.json()
|
||||
return jdata["data"]
|
||||
except Exception:
|
||||
# This method is only used for testing purposes - it's
|
||||
# safe to just re-raise the exception here
|
||||
raise
|
||||
|
||||
def _test_read_models(self):
|
||||
"""
|
||||
Read the model from RemoteSettings. This method is only used
|
||||
|
@ -139,7 +175,8 @@ class CFRRemoteSettings:
|
|||
return jdata
|
||||
|
||||
def write_models(self, json_data):
|
||||
jsonschema.validate(json_data, self.schema)
|
||||
# TODO: we need a new schema validator
|
||||
# jsonschema.validate(json_data, self.schema)
|
||||
if not self.check_model_exists():
|
||||
if not self.create_model_collection():
|
||||
raise SecurityError("cfr-model collection could not be created.")
|
||||
|
@ -205,20 +242,66 @@ class CFRRemoteSettings:
|
|||
"{} collection could not be created.".format(CFR_EXPERIMENT)
|
||||
)
|
||||
|
||||
if not self._clone_cfr_to(cfr_data, CFR_EXPERIMENT):
|
||||
return False
|
||||
# Test CFR Message added from test plan
|
||||
# https://docs.google.com/document/d/1_aPEj_XS83qzDphOVGWk70vkbaACd270Hn9I4fY7fuE/edit#heading=h.77k16ftk1hea
|
||||
test_cfr = {
|
||||
"id": "PERSONALIZED_CFR_MESSAGE",
|
||||
"template": "cfr_doorhanger",
|
||||
"content": {
|
||||
"layout": "icon_and_message",
|
||||
"category": "cfrFeatures",
|
||||
"notification_text": "Personalized CFR Recommendation",
|
||||
"heading_text": {"string_id": "cfr-doorhanger-firefox-send-header"},
|
||||
"info_icon": {
|
||||
"label": {"string_id": "cfr-doorhanger-extension-sumo-link"},
|
||||
"sumo_path": "https://example.com",
|
||||
},
|
||||
"text": {"string_id": "cfr-doorhanger-firefox-send-body"},
|
||||
"icon": "chrome://branding/content/icon64.png",
|
||||
"buttons": {
|
||||
"primary": {
|
||||
"label": {"string_id": "cfr-doorhanger-firefox-send-ok-button"},
|
||||
"action": {
|
||||
"type": "OPEN_URL",
|
||||
"data": {
|
||||
"args": "https://send.firefox.com/login/?utm_source=activity-stream&entrypoint=activity-stream-cfr-pdf",
|
||||
"where": "tabshifted",
|
||||
},
|
||||
},
|
||||
},
|
||||
"secondary": [
|
||||
{
|
||||
"label": {
|
||||
"string_id": "cfr-doorhanger-extension-cancel-button"
|
||||
},
|
||||
"action": {"type": "CANCEL"},
|
||||
},
|
||||
{
|
||||
"label": {
|
||||
"string_id": "cfr-doorhanger-extension-never-show-recommendation"
|
||||
}
|
||||
},
|
||||
{
|
||||
"label": {
|
||||
"string_id": "cfr-doorhanger-extension-manage-settings-button"
|
||||
},
|
||||
"action": {
|
||||
"type": "OPEN_PREFERENCES_PAGE",
|
||||
"data": {"category": "general-cfrfeatures"},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
"targeting": "scores.PERSONALIZED_CFR_MESSAGE > scoreThreshold",
|
||||
"trigger": {"id": "openURL", "patterns": ["*://*/*.pdf"]},
|
||||
}
|
||||
|
||||
# Write in the targetting attribute
|
||||
|
||||
auth = HTTPBasicAuth(self._kinto_user, self._kinto_pass)
|
||||
obj_id = "targetting"
|
||||
url = "{base_uri:s}/buckets/main/collections/{c_id:s}/records/{obj_id:s}".format(
|
||||
base_uri=self._kinto_uri, c_id=CFR_EXPERIMENT, obj_id=obj_id
|
||||
)
|
||||
obj = {"targetting": "scores.PERSONALIZED_CFR_MESSAGE > scoreThreshold"}
|
||||
resp = requests.put(url, json={"data": obj}, auth=auth)
|
||||
if resp.status_code > 299:
|
||||
raise RemoteSettingWriteError(
|
||||
"Error writing targetting expression to experiment bucket"
|
||||
for record in cfr_data:
|
||||
record[
|
||||
"targeting"
|
||||
] = "({old_targeting:s}) && personalizedCfrScores.{id:s} > personalizedCfrThreshold".format(
|
||||
old_targeting=record["targeting"], id=record["id"]
|
||||
)
|
||||
return True
|
||||
cfr_data.append(test_cfr)
|
||||
return self._clone_cfr_to(cfr_data, CFR_EXPERIMENT)
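
In other words, each cloned record keeps its original targeting expression but gates it on the per-CFR personalized score. A before/after example with a made-up record:

```python
# Hypothetical record; only the rewrite pattern is taken from the code above.
record = {"id": "PIN_TAB", "targeting": "totalBookmarksCount > 5"}
record["targeting"] = (
    "({old_targeting:s}) && personalizedCfrScores.{id:s} > personalizedCfrThreshold".format(
        old_targeting=record["targeting"], id=record["id"]
    )
)
# record["targeting"] is now:
# "(totalBookmarksCount > 5) && personalizedCfrScores.PIN_TAB > personalizedCfrThreshold"
```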
|
||||
|
|
|
@ -36,219 +36,168 @@
|
|||
"type": "integer",
|
||||
"title": "The Bookmark_sync_cfr Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
10476
|
||||
]
|
||||
"examples": [10476]
|
||||
},
|
||||
"CRYPTOMINERS_PROTECTION": {
|
||||
"$id": "#/properties/CRYPTOMINERS_PROTECTION",
|
||||
"type": "integer",
|
||||
"title": "The Cryptominers_protection Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
824
|
||||
]
|
||||
"examples": [824]
|
||||
},
|
||||
"CRYPTOMINERS_PROTECTION_71": {
|
||||
"$id": "#/properties/CRYPTOMINERS_PROTECTION_71",
|
||||
"type": "integer",
|
||||
"title": "The Cryptominers_protection_71 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
409
|
||||
]
|
||||
"examples": [409]
|
||||
},
|
||||
"FACEBOOK_CONTAINER_3": {
|
||||
"$id": "#/properties/FACEBOOK_CONTAINER_3",
|
||||
"type": "integer",
|
||||
"title": "The Facebook_container_3 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
12149
|
||||
]
|
||||
"examples": [12149]
|
||||
},
|
||||
"FACEBOOK_CONTAINER_3_72": {
|
||||
"$id": "#/properties/FACEBOOK_CONTAINER_3_72",
|
||||
"type": "integer",
|
||||
"title": "The Facebook_container_3_72 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
4506
|
||||
]
|
||||
"examples": [4506]
|
||||
},
|
||||
"FINGERPRINTERS_PROTECTION": {
|
||||
"$id": "#/properties/FINGERPRINTERS_PROTECTION",
|
||||
"type": "integer",
|
||||
"title": "The Fingerprinters_protection Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
4012
|
||||
]
|
||||
"examples": [4012]
|
||||
},
|
||||
"FINGERPRINTERS_PROTECTION_71": {
|
||||
"$id": "#/properties/FINGERPRINTERS_PROTECTION_71",
|
||||
"type": "integer",
|
||||
"title": "The Fingerprinters_protection_71 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
3657
|
||||
]
|
||||
"examples": [3657]
|
||||
},
|
||||
"GOOGLE_TRANSLATE_3": {
|
||||
"$id": "#/properties/GOOGLE_TRANSLATE_3",
|
||||
"type": "integer",
|
||||
"title": "The Google_translate_3 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
2286
|
||||
]
|
||||
"examples": [2286]
|
||||
},
|
||||
"GOOGLE_TRANSLATE_3_72": {
|
||||
"$id": "#/properties/GOOGLE_TRANSLATE_3_72",
|
||||
"type": "integer",
|
||||
"title": "The Google_translate_3_72 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
12066
|
||||
]
|
||||
"examples": [12066]
|
||||
},
|
||||
"MILESTONE_MESSAGE": {
|
||||
"$id": "#/properties/MILESTONE_MESSAGE",
|
||||
"type": "integer",
|
||||
"title": "The Milestone_message Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
1679
|
||||
]
|
||||
"examples": [1679]
|
||||
},
|
||||
"PDF_URL_FFX_SEND": {
|
||||
"$id": "#/properties/PDF_URL_FFX_SEND",
|
||||
"type": "integer",
|
||||
"title": "The Pdf_url_ffx_send Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
11087
|
||||
]
|
||||
"examples": [11087]
|
||||
},
|
||||
"PIN_TAB": {
|
||||
"$id": "#/properties/PIN_TAB",
|
||||
"type": "integer",
|
||||
"title": "The Pin_tab Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
12135
|
||||
]
|
||||
"examples": [12135]
|
||||
},
|
||||
"PIN_TAB_72": {
|
||||
"$id": "#/properties/PIN_TAB_72",
|
||||
"type": "integer",
|
||||
"title": "The Pin_tab_72 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
14617
|
||||
]
|
||||
"examples": [14617]
|
||||
},
|
||||
"SAVE_LOGIN": {
|
||||
"$id": "#/properties/SAVE_LOGIN",
|
||||
"type": "integer",
|
||||
"title": "The Save_login Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
8935
|
||||
]
|
||||
"examples": [8935]
|
||||
},
|
||||
"SAVE_LOGIN_72": {
|
||||
"$id": "#/properties/SAVE_LOGIN_72",
|
||||
"type": "integer",
|
||||
"title": "The Save_login_72 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
1424
|
||||
]
|
||||
"examples": [1424]
|
||||
},
|
||||
"SEND_RECIPE_TAB_CFR": {
|
||||
"$id": "#/properties/SEND_RECIPE_TAB_CFR",
|
||||
"type": "integer",
|
||||
"title": "The Send_recipe_tab_cfr Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
9674
|
||||
]
|
||||
"examples": [9674]
|
||||
},
|
||||
"SEND_TAB_CFR": {
|
||||
"$id": "#/properties/SEND_TAB_CFR",
|
||||
"type": "integer",
|
||||
"title": "The Send_tab_cfr Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
6912
|
||||
]
|
||||
"examples": [6912]
|
||||
},
|
||||
"SOCIAL_TRACKING_PROTECTION": {
|
||||
"$id": "#/properties/SOCIAL_TRACKING_PROTECTION",
|
||||
"type": "integer",
|
||||
"title": "The Social_tracking_protection Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
520
|
||||
]
|
||||
"examples": [520]
|
||||
},
|
||||
"SOCIAL_TRACKING_PROTECTION_71": {
|
||||
"$id": "#/properties/SOCIAL_TRACKING_PROTECTION_71",
|
||||
"type": "integer",
|
||||
"title": "The Social_tracking_protection_71 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
488
|
||||
]
|
||||
"examples": [488]
|
||||
},
|
||||
"WNP_MOMENTS_1": {
|
||||
"$id": "#/properties/WNP_MOMENTS_1",
|
||||
"type": "integer",
|
||||
"title": "The Wnp_moments_1 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
1535
|
||||
]
|
||||
"examples": [1535]
|
||||
},
|
||||
"WNP_MOMENTS_2": {
|
||||
"$id": "#/properties/WNP_MOMENTS_2",
|
||||
"type": "integer",
|
||||
"title": "The Wnp_moments_2 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
3582
|
||||
]
|
||||
"examples": [3582]
|
||||
},
|
||||
"WNP_MOMENTS_SYNC": {
|
||||
"$id": "#/properties/WNP_MOMENTS_SYNC",
|
||||
"type": "integer",
|
||||
"title": "The Wnp_moments_sync Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
3811
|
||||
]
|
||||
"examples": [3811]
|
||||
},
|
||||
"YOUTUBE_ENHANCE_3": {
|
||||
"$id": "#/properties/YOUTUBE_ENHANCE_3",
|
||||
"type": "integer",
|
||||
"title": "The Youtube_enhance_3 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
8279
|
||||
]
|
||||
"examples": [8279]
|
||||
},
|
||||
"YOUTUBE_ENHANCE_3_72": {
|
||||
"$id": "#/properties/YOUTUBE_ENHANCE_3_72",
|
||||
"type": "integer",
|
||||
"title": "The Youtube_enhance_3_72 Schema",
|
||||
"default": 0,
|
||||
"examples": [
|
||||
9863
|
||||
]
|
||||
"examples": [9863]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
workerConfig:
|
||||
minInstances: 2 # this is a GCP minimum
|
||||
maxInstances: 2 # keep primary workers fixed at the GCP minimum
|
||||
weight: 1
|
||||
secondaryWorkerConfig:
|
||||
minInstances: 0 # zero pre-emptible nodes are allowed
|
||||
maxInstances: 20
|
||||
weight: 10
|
||||
basicAlgorithm:
|
||||
cooldownPeriod: 4m
|
||||
yarnConfig:
|
||||
scaleUpFactor: 0.05
|
||||
scaleDownFactor: 1.0
|
||||
scaleUpMinWorkerFraction: 0.0
|
||||
scaleDownMinWorkerFraction: 0.0
|
||||
gracefulDecommissionTimeout: 1h
|
The diff for this file is not shown because it is too large.
|
@ -0,0 +1,73 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
from pyspark.sql import SparkSession
|
||||
from sklearn import metrics
|
||||
|
||||
# Import numerical computation libraries
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.naive_bayes import GaussianNB, CategoricalNB
|
||||
from sklearn.preprocessing import OrdinalEncoder
|
||||
import numpy as np
|
||||
import pandas as pd # noqa
|
||||
|
||||
|
||||
def check_sklearn_dev():
|
||||
'''
|
||||
This just verifies that sklearn 0.23-dev is installed properly
|
||||
by checking CategoricalNB results
|
||||
'''
|
||||
rng = np.random.RandomState(1)
|
||||
X = rng.randint(5, size=(6, 100))
|
||||
y = np.array([1, 2, 3, 4, 5, 6])
|
||||
|
||||
clf = CategoricalNB()
|
||||
clf.fit(X, y)
|
||||
assert [3] == clf.predict(X[2:3])
|
||||
|
||||
def spark_query_bq():
|
||||
spark = (
|
||||
SparkSession.builder.master("yarn").appName("spark-bigquery-demo").getOrCreate()
|
||||
)
|
||||
|
||||
# Use the Cloud Storage bucket for temporary BigQuery export data used
|
||||
# by the connector. This assumes the Cloud Storage connector for
|
||||
# Hadoop is configured.
|
||||
bucket = spark.sparkContext._jsc.hadoopConfiguration().get("fs.gs.system.bucket")
|
||||
spark.conf.set("temporaryGcsBucket", bucket)
|
||||
|
||||
str_ts = "2019-11-24 00:00:00.000000 UTC"
|
||||
end_ts = "2019-11-25 00:00:00.000000 UTC"
|
||||
|
||||
# Load data from BigQuery.
|
||||
df = (
|
||||
spark.read.format("bigquery")
|
||||
.option("table", "moz-fx-data-shar-nonprod-efed:messaging_system_live.cfr_v1")
|
||||
.option("filter", "submission_timestamp >= '{str_ts:s}'".format(str_ts=str_ts))
|
||||
.option("filter", "submission_timestamp >= '{end_ts:s}'".format(end_ts=end_ts))
|
||||
.load()
|
||||
)
|
||||
|
||||
df.createOrReplaceTempView("cfr_v1")
|
||||
|
||||
sql_template = """select * from cfr_v1"""
|
||||
|
||||
# Demonstrate we can hook a pyspark dataframe here
|
||||
sql = sql_template.format(str_ts=str_ts, end_ts=end_ts)
|
||||
row_df = spark.sql(sql)
|
||||
|
||||
print("Fetched {:d} rows from bq via spark".format(row_df.count()))
|
||||
|
||||
# Create a bunch of dummy CFR_IDs with weights
|
||||
row_df = spark.createDataFrame(
|
||||
[("CFR_1", 1), ("CFR_2", 5), ("CFR_3", 9), ("CFR_4", 21), ("CFR_5", 551)]
|
||||
)
|
||||
return row_df
|
||||
|
||||
|
||||
check_sklearn_dev()
|
||||
row_df = spark_query_bq()
|
||||
# Stuff the results into a versioned table
|
||||
row_df.write.format("bigquery").mode("overwrite").option(
|
||||
"table", "cfr_etl.cfr_weights_v010"
|
||||
).save()
|
|
@ -0,0 +1,28 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -exo pipefail
|
||||
|
||||
function install_pip() {
|
||||
if command -v pip >/dev/null; then
|
||||
echo "pip is already installed."
|
||||
return 0
|
||||
fi
|
||||
|
||||
if command -v easy_install >/dev/null; then
|
||||
echo "Installing pip with easy_install..."
|
||||
easy_install pip
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "Installing python-pip..."
|
||||
apt update
|
||||
apt install python-pip -y
|
||||
}
|
||||
|
||||
function main() {
|
||||
install_pip
|
||||
pip install --upgrade --pre -f https://sklearn-nightly.scdn8.secure.raxcdn.com scikit-learn
|
||||
pip install google-cloud-bigquery==1.21.0 pandas==0.25.3 numpy==1.17.3
|
||||
}
|
||||
|
||||
main
|
|
@ -0,0 +1,46 @@
|
|||
"""
|
||||
This script installs fixture test data into a kinto server
|
||||
"""
|
||||
|
||||
from cfretl.remote_settings import CFRRemoteSettings
|
||||
from cfretl.remote_settings import FEATURES_LIST
|
||||
from cfretl.models import CFRModel
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
remote_settings = CFRRemoteSettings()
|
||||
|
||||
version_code = int(datetime.utcnow().strftime("%Y%m%d%H%m%S"))
|
||||
|
||||
CFRS = remote_settings.debug_load_cfr()
|
||||
CFR_ID_LIST = [r["id"] for r in CFRS]
|
||||
|
||||
|
||||
def generate_cfr_cfgdata():
|
||||
"""
|
||||
This function will need to be rewritten to parse the
|
||||
BQ output table and coerce it into values for RemoteSettings JSON
|
||||
blob
|
||||
"""
|
||||
cfg_data = {}
|
||||
for cfr_id in CFR_ID_LIST:
|
||||
|
||||
# TODO: replace this with prior 0
|
||||
p0 = random.random()
|
||||
|
||||
cfg_data[cfr_id] = {"p0": p0, "p1": 1 - p0, "features": []}
|
||||
for f_id in FEATURES_LIST:
|
||||
cfg_data[cfr_id]["features"].append(
|
||||
{"feature_id": f_id, "p0": random.random(), "p1": random.random()}
|
||||
)
|
||||
return cfg_data
|
||||
|
||||
|
||||
model = CFRModel()
|
||||
json_model = model.generate_cfr_model(generate_cfr_cfgdata(), version_code)
|
||||
remote_settings.write_models(json_model)
|
||||
|
||||
remote_settings.clone_to_cfr_control(CFRS)
|
||||
remote_settings.clone_to_cfr_experiment(CFRS)
|
||||
print("Wrote out version : {:d}".format(version_code))
|
||||
print("=" * 20, datetime.now(), "=" * 20)
|
|
@ -0,0 +1,12 @@
|
|||
#!/bin/sh
|
||||
SERVER=https://kinto.dev.mozaws.net/v1
|
||||
|
||||
while true
|
||||
do
|
||||
# Create the bot first to start clean
|
||||
curl -v -X PUT ${SERVER}/accounts/devuser \
|
||||
-d '{"data": {"password": "devpass"}}' \
|
||||
-H 'Content-Type:application/json'
|
||||
python install_fixtures.py
|
||||
sleep 10
|
||||
done
|
|
@ -5,8 +5,6 @@
|
|||
from decouple import config
|
||||
|
||||
|
||||
# Default CFR Vector width is 7
|
||||
|
||||
KINTO_BUCKET = "main"
|
||||
|
||||
KINTO_URI = config("KINTO_URI", "https://kinto.dev.mozaws.net/v1")
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
from google.cloud.dataproc_v1.gapic.transports import cluster_controller_grpc_transport
|
||||
from google.cloud import dataproc_v1
|
||||
|
||||
waiting_callback = False
|
||||
|
||||
|
||||
def cluster_callback(operation_future):
|
||||
# Reset global when callback returns.
|
||||
global waiting_callback
|
||||
waiting_callback = False
|
||||
|
||||
|
||||
def get_region_from_zone(zone):
|
||||
try:
|
||||
region_as_list = zone.split("-")[:-1]
|
||||
return "-".join(region_as_list)
|
||||
except (AttributeError, IndexError, ValueError):
|
||||
raise ValueError("Invalid zone provided, please check your input.")
|
||||
|
||||
|
||||
def dataproc_cluster_client(zone):
|
||||
"""
|
||||
Lazily create a Dataproc ClusterController client to setup or
|
||||
tear down dataproc clusters
|
||||
"""
|
||||
region = get_region_from_zone(zone)
|
||||
|
||||
client_transport = cluster_controller_grpc_transport.ClusterControllerGrpcTransport(
|
||||
address="{}-dataproc.googleapis.com:443".format(region)
|
||||
)
|
||||
return dataproc_v1.ClusterControllerClient(client_transport)
|
||||
|
||||
|
||||
def create_cluster(cluster_name, project_id, zone):
|
||||
"""Create the cluster."""
|
||||
|
||||
# TODO: pass in the bucket somehow. maybe an attribute
|
||||
# settings
|
||||
# bucket_name = "cfr-ml-jobs"
|
||||
|
||||
print("Creating cluster...")
|
||||
cluster_config = {
|
||||
"cluster_name": cluster_name,
|
||||
"project_id": project_id,
|
||||
"config": {
|
||||
"master_config": {
|
||||
"num_instances": 1,
|
||||
"machine_type_uri": "n1-standard-1",
|
||||
"disk_config": {
|
||||
"boot_disk_type": "pd-ssd",
|
||||
"num_local_ssds": 1,
|
||||
"boot_disk_size_gb": 1000,
|
||||
},
|
||||
},
|
||||
"worker_config": {
|
||||
"num_instances": 2,
|
||||
"machine_type_uri": "n1-standard-8",
|
||||
"disk_config": {
|
||||
"boot_disk_type": "pd-ssd",
|
||||
"num_local_ssds": 1,
|
||||
"boot_disk_size_gb": 1000,
|
||||
},
|
||||
},
|
||||
"autoscaling_config": {
|
||||
"policy_uri": "projects/cfr-personalization-experiment/regions/us-west1/autoscalingPolicies/cfr-personalization-autoscale"
|
||||
},
|
||||
"initialization_actions": [
|
||||
{
|
||||
"executable_file": "gs://cfr-ml-jobs/actions/python/dataproc_custom.sh"
|
||||
}
|
||||
],
|
||||
"software_config": {"image_version": "1.4.16-ubuntu18"},
|
||||
},
|
||||
}
|
||||
|
||||
cluster = dataproc_cluster_client(zone).create_cluster(
|
||||
project_id, get_region_from_zone(zone), cluster_config
|
||||
)
|
||||
cluster.add_done_callback(cluster_callback)
|
||||
global waiting_callback
|
||||
waiting_callback = True
|
||||
|
||||
|
||||
def wait_for_cluster_creation():
|
||||
"""Wait for cluster creation."""
|
||||
print("Waiting for cluster creation...")
|
||||
|
||||
while True:
|
||||
if not waiting_callback:
|
||||
print("Cluster created.")
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cluster_name = "cfr-demo-cluster-py3"
|
||||
project_id = "cfr-personalization-experiment"
|
||||
zone = "us-west1-a"
|
||||
create_cluster(cluster_name, project_id, zone)
|
||||
wait_for_cluster_creation()
|
|
@ -1,24 +1,22 @@
|
|||
# CFR Machine Learning
|
||||
# CFR Machine Learning
|
||||
|
||||
[![Build Status](https://travis-ci.org/mozilla/cfr-personalization.svg?branch=master)](https://travis-ci.org/mozilla/cfr-personalization)
|
||||
|
||||
Table of Contents (ToC):
|
||||
===========================
|
||||
# Table of Contents (ToC):
|
||||
|
||||
* [How does it work?](#how-does-it-work)
|
||||
* [Building and Running tests](#build-and-run-tests)
|
||||
- [How does it work?](#how-does-it-work)
|
||||
- [Building and Running tests](#build-and-run-tests)
|
||||
|
||||
## How does it work?
|
||||
|
||||
CFR instrumentation works by reading telemetry pings directly from
|
||||
'live' tables.
|
||||
'live' tables.
|
||||
|
||||
Those pings go through a ML pass to generate a new set of weights and
|
||||
we write directly into Remote Settings.
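
The last step of that pipeline is a plain Remote Settings write. A minimal sketch, using the mock model bundled with this commit as a stand-in for real ML output:

```python
from cfretl.main import load_mock_model          # stand-in until the BQ output table is transformed
from cfretl.remote_settings import CFRRemoteSettings

remote_settings = CFRRemoteSettings()
remote_settings.write_models(load_mock_model())  # publish the weights to the cfr-ml-model collection
```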
|
||||
|
||||
TODO: write more about ML layer here
|
||||
|
||||
|
||||
Some terminology is important:
|
||||
|
||||
In the context of Remote Settings, CFR uses the 'main' bucket.
|
||||
|
@ -29,18 +27,12 @@ Each 'Provider ID' in the about:newtab#devtools page is called a
|
|||
To minimize impact on production, we constrain the places where
|
||||
we can write into the 'main' bucket.
|
||||
|
||||
|
||||
![Collections are not 'Buckets'](./rs_collections.jpg "Collections are not Buckets")
|
||||
|
||||
|
||||
CFR-Personalization will:
|
||||
* *only* operate on collections within the 'main' bucket
|
||||
* *only* write to buckets with a prefix 'cfr-exp-'
|
||||
* all writes to a collection will first be validated
|
||||
|
||||
|
||||
CFR-Personalization will: \* _only_ operate on collections within the 'main' bucket \* _only_ write to buckets with a prefix 'cfr-exp-' \* all writes to a collection will first be validated
|
||||
|
||||
## Building and running tests
|
||||
|
||||
You should be able to build cfr-personalization using Python 3.7
|
||||
|
||||
To run the testsuite, execute ::
|
||||
|
@ -49,4 +41,3 @@ To run the testsuite, execute ::
|
|||
$ python setup.py develop
|
||||
$ python setup.py test
|
||||
```
|
||||
|
||||
|
|
|
@ -2,30 +2,32 @@ The CFR weight vector schema is computed by using the current set of
|
|||
CFR message IDs as keys and integer values in a hashmap.
|
||||
|
||||
```json
|
||||
{'BOOKMARK_SYNC_CFR': 10476,
|
||||
'CRYPTOMINERS_PROTECTION': 1824,
|
||||
'CRYPTOMINERS_PROTECTION_71': 409,
|
||||
'FACEBOOK_CONTAINER_3': 12149,
|
||||
'FACEBOOK_CONTAINER_3_72': 4506,
|
||||
'FINGERPRINTERS_PROTECTION': 4012,
|
||||
'FINGERPRINTERS_PROTECTION_71': 3657,
|
||||
'GOOGLE_TRANSLATE_3': 2286,
|
||||
'GOOGLE_TRANSLATE_3_72': 12066,
|
||||
'MILESTONE_MESSAGE': 1679,
|
||||
'PDF_URL_FFX_SEND': 11087,
|
||||
'PIN_TAB': 12135,
|
||||
'PIN_TAB_72': 14617,
|
||||
'SAVE_LOGIN': 8935,
|
||||
'SAVE_LOGIN_72': 1424,
|
||||
'SEND_RECIPE_TAB_CFR': 9674,
|
||||
'SEND_TAB_CFR': 6912,
|
||||
'SOCIAL_TRACKING_PROTECTION': 520,
|
||||
'SOCIAL_TRACKING_PROTECTION_71': 488,
|
||||
'WNP_MOMENTS_1': 1535,
|
||||
'WNP_MOMENTS_2': 3582,
|
||||
'WNP_MOMENTS_SYNC': 3811,
|
||||
'YOUTUBE_ENHANCE_3': 8279,
|
||||
'YOUTUBE_ENHANCE_3_72': 9863}
|
||||
{
|
||||
"BOOKMARK_SYNC_CFR": 10476,
|
||||
"CRYPTOMINERS_PROTECTION": 1824,
|
||||
"CRYPTOMINERS_PROTECTION_71": 409,
|
||||
"FACEBOOK_CONTAINER_3": 12149,
|
||||
"FACEBOOK_CONTAINER_3_72": 4506,
|
||||
"FINGERPRINTERS_PROTECTION": 4012,
|
||||
"FINGERPRINTERS_PROTECTION_71": 3657,
|
||||
"GOOGLE_TRANSLATE_3": 2286,
|
||||
"GOOGLE_TRANSLATE_3_72": 12066,
|
||||
"MILESTONE_MESSAGE": 1679,
|
||||
"PDF_URL_FFX_SEND": 11087,
|
||||
"PIN_TAB": 12135,
|
||||
"PIN_TAB_72": 14617,
|
||||
"SAVE_LOGIN": 8935,
|
||||
"SAVE_LOGIN_72": 1424,
|
||||
"SEND_RECIPE_TAB_CFR": 9674,
|
||||
"SEND_TAB_CFR": 6912,
|
||||
"SOCIAL_TRACKING_PROTECTION": 520,
|
||||
"SOCIAL_TRACKING_PROTECTION_71": 488,
|
||||
"WNP_MOMENTS_1": 1535,
|
||||
"WNP_MOMENTS_2": 3582,
|
||||
"WNP_MOMENTS_SYNC": 3811,
|
||||
"YOUTUBE_ENHANCE_3": 8279,
|
||||
"YOUTUBE_ENHANCE_3_72": 9863
|
||||
}
|
||||
```
|
||||
|
||||
The JSON schema is computed using https://jsonschema.net/
|
||||
|
@ -33,4 +35,3 @@ The JSON schema is computed using https://jsonschema.net/
|
|||
Note that the schema generated will enforce the length of the vector
|
||||
as each CFR Message ID is a required key in the JSON blob that is
|
||||
returned.
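
For reference, validating a candidate weight vector against that schema is a one-liner with the `jsonschema` package already listed in requirements.txt; this is only a sketch, and the file names are hypothetical:

```python
import json
import jsonschema

# Hypothetical file names, for illustration only.
schema = json.load(open("cfr_weights.schema.json"))   # schema generated from the example above
vector = json.load(open("cfr_weights.json"))          # the full weight vector keyed by CFR ID

# Raises jsonschema.ValidationError if a required CFR ID is missing
# or a value is not an integer.
jsonschema.validate(vector, schema)
```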
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
appnope==0.1.0
|
||||
astroid==2.3.3
|
||||
atomicwrites==1.3.0
|
||||
attrs==19.3.0
|
||||
backcall==0.1.0
|
||||
|
@ -13,17 +14,23 @@ google-api-core==1.14.3
|
|||
google-auth==1.7.0
|
||||
google-cloud-bigquery==1.21.0
|
||||
google-cloud-core==1.0.3
|
||||
google-resumable-media==0.4.1
|
||||
google-cloud-dataproc==0.6.1
|
||||
google-cloud-storage==1.23.0
|
||||
google-resumable-media==0.5.0
|
||||
googleapis-common-protos==1.6.0
|
||||
grpcio==1.25.0
|
||||
idna==2.8
|
||||
importlib-metadata==0.23
|
||||
ipython==7.9.0
|
||||
ipython-genutils==0.2.0
|
||||
isort==4.3.21
|
||||
jedi==0.15.1
|
||||
joblib==0.14.0
|
||||
jsonschema==3.1.1
|
||||
lazy-object-proxy==1.4.3
|
||||
mccabe==0.6.1
|
||||
more-itertools==7.2.0
|
||||
-e git+git@github.com:mozilla/messaging-system-personalization-experiment-1-numbermuncher.git@5b44909f8b1392efce6490e2b42a2ac9be5e8477#egg=mozilla_cfr_personalization
|
||||
numpy==1.17.3
|
||||
packaging==19.2
|
||||
pandas==0.25.3
|
||||
|
@ -35,13 +42,16 @@ prompt-toolkit==2.0.10
|
|||
protobuf==3.10.0
|
||||
ptyprocess==0.6.0
|
||||
py==1.8.0
|
||||
py4j==0.10.7
|
||||
pyasn1==0.4.7
|
||||
pyasn1-modules==0.2.7
|
||||
pycodestyle==2.5.0
|
||||
pyflakes==2.1.1
|
||||
Pygments==2.4.2
|
||||
pylint==2.4.4
|
||||
pyparsing==2.4.4
|
||||
pyrsistent==0.15.5
|
||||
pyspark==2.4.4
|
||||
pytest==5.2.2
|
||||
pytest-cov==2.8.1
|
||||
python-dateutil==2.8.1
|
||||
|
@ -49,12 +59,15 @@ python-decouple==3.1
|
|||
pytz==2019.3
|
||||
requests==2.22.0
|
||||
requests-mock==1.7.0
|
||||
rope==0.14.0
|
||||
rsa==4.0
|
||||
scikit-learn==0.21.3
|
||||
scipy==1.3.1
|
||||
six==1.13.0
|
||||
sklearn==0.0
|
||||
traitlets==4.3.3
|
||||
typed-ast==1.4.0
|
||||
urllib3==1.25.6
|
||||
wcwidth==0.1.7
|
||||
wrapt==1.11.2
|
||||
zipp==0.6.0
|
||||
|
|
|
@ -87,7 +87,6 @@
|
|||
"enum": ["daily"]
|
||||
}
|
||||
]
|
||||
|
||||
},
|
||||
"cap": {
|
||||
"type": "integer",
|
||||
|
|
|
@ -1,72 +0,0 @@
|
|||
#!/bin/sh
|
||||
|
||||
# Create a cfr-bot with a dummy password for local testing
|
||||
SERVER=http://localhost:8888/v1
|
||||
|
||||
curl -X PUT ${SERVER}/accounts/cfr-bot \
|
||||
-d '{"data": {"password": "botpass"}}' \
|
||||
-H 'Content-Type:application/json'
|
||||
|
||||
BASIC_AUTH=cfr-bot:botpass
|
||||
|
||||
# Create 3 collections
|
||||
# * main/cfr-models
|
||||
# * main/cfr-experiment
|
||||
# * main/cfr-control
|
||||
|
||||
curl -X PUT ${SERVER}/buckets/main/collections/cfr-models \
|
||||
-H 'Content-Type:application/json' \
|
||||
-u ${BASIC_AUTH}
|
||||
|
||||
curl -X PUT ${SERVER}/buckets/main/collections/cfr-experiment \
|
||||
-H 'Content-Type:application/json' \
|
||||
-u ${BASIC_AUTH}
|
||||
|
||||
curl -X PUT ${SERVER}/buckets/main/collections/cfr-control \
|
||||
-H 'Content-Type:application/json' \
|
||||
-u ${BASIC_AUTH}
|
||||
|
||||
# Add the bot to editor role for 3 collections:
|
||||
# * main/cfr-models
|
||||
# * main/cfr-experiment
|
||||
# * main/cfr-control
|
||||
|
||||
curl -X PATCH $SERVER/buckets/main/groups/cfr-models-editors \
|
||||
-H 'Content-Type:application/json-patch+json' \
|
||||
-d '[{ "op": "add", "path": "/data/members/0", "value": "account:cfr-bot" }]' \
|
||||
-u ${BASIC_AUTH}
|
||||
|
||||
curl -X PATCH $SERVER/buckets/main/groups/cfr-experiment-editors \
|
||||
-H 'Content-Type:application/json-patch+json' \
|
||||
-d '[{ "op": "add", "path": "/data/members/0", "value": "account:cfr-bot" }]' \
|
||||
-u ${BASIC_AUTH}
|
||||
|
||||
curl -X PATCH $SERVER/buckets/main/groups/cfr-control-editors \
|
||||
-H 'Content-Type:application/json-patch+json' \
|
||||
-d '[{ "op": "add", "path": "/data/members/0", "value": "account:cfr-bot" }]' \
|
||||
-u ${BASIC_AUTH}
|
||||
|
||||
# Add the bot to reviewer role for 3 collections:
|
||||
# * main/cfr-models
|
||||
# * main/cfr-experiment
|
||||
# * main/cfr-control
|
||||
|
||||
curl -X PATCH $SERVER/buckets/main/groups/cfr-models-reviewers \
|
||||
-H 'Content-Type:application/json-patch+json' \
|
||||
-d '[{ "op": "add", "path": "/data/members/0", "value": "account:cfr-bot" }]' \
|
||||
-u ${BASIC_AUTH}
|
||||
curl -X PATCH $SERVER/buckets/main/groups/cfr-experiment-reviewers \
|
||||
-H 'Content-Type:application/json-patch+json' \
|
||||
-d '[{ "op": "add", "path": "/data/members/0", "value": "account:cfr-bot" }]' \
|
||||
-u ${BASIC_AUTH}
|
||||
curl -X PATCH $SERVER/buckets/main/groups/cfr-control-reviewers \
|
||||
-H 'Content-Type:application/json-patch+json' \
|
||||
-d '[{ "op": "add", "path": "/data/members/0", "value": "account:cfr-bot" }]' \
|
||||
-u ${BASIC_AUTH}
|
||||
|
||||
# Generate some dummy data in the cfr-models bucket
|
||||
|
||||
curl -X PUT ${SERVER}/buckets/main/collections/cfr-models/records/cfr-models \
|
||||
-H 'Content-Type:application/json' \
|
||||
-d "{\"data\": {\"property\": 321.1}}" \
|
||||
-u ${BASIC_AUTH} --verbose
|
3
setup.py
|
@ -1,7 +1,7 @@
|
|||
from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
name="mozilla-cfr-personalization",
|
||||
name="CFR Personalization",
|
||||
use_scm_version=False,
|
||||
version="0.1.0",
|
||||
setup_requires=["setuptools_scm", "pytest-runner"],
|
||||
|
@ -14,6 +14,7 @@ setup(
|
|||
url="https://github.com/mozilla/cfr-personalization",
|
||||
license="MPL 2.0",
|
||||
install_requires=[],
|
||||
data_files=[("scripts", ["scripts/compute_weights.py"])],
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Environment :: Web Environment :: Mozilla",
|
||||
|
|
|
@ -2,13 +2,11 @@
|
|||
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
import datetime
|
||||
import json
|
||||
import pytest
|
||||
import pytz
|
||||
|
||||
from cfretl.asloader import ASLoader
|
||||
|
||||
import os
|
||||
import random
|
||||
|
@ -93,25 +91,52 @@ def FIXTURE_JSON():
|
|||
"value": '{"card_type": "pinned", "icon_type": "screenshot_with_icon"}',
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def asloader(FIXTURE_JSON, WEIGHT_VECTOR):
|
||||
asl = ASLoader()
|
||||
|
||||
# Clobber the inbound pings
|
||||
asl._get_pings = MagicMock(return_value=[FIXTURE_JSON])
|
||||
|
||||
# Clobber the model generation
|
||||
asl.compute_vector_weights = MagicMock(return_value=WEIGHT_VECTOR)
|
||||
|
||||
return asl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def MOCK_CFR_DATA():
|
||||
return json.load(open(os.path.join(FIXTURE_PATH, "cfr.json")))["data"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def WEIGHT_VECTOR():
|
||||
return dict(zip(CFR_IDS, [random.randint(0, 16000) for i in range(len(CFR_IDS))]))
|
||||
return {
|
||||
"CFR_ID_1": {
|
||||
"feature_1": {"p_given_cfr_acceptance": 0.7, "p_given_cfr_rejection": 0.5}
|
||||
},
|
||||
"CFR_ID_2": {
|
||||
"feature_1": {
|
||||
"p_given_cfr_acceptance": 0.49,
|
||||
"p_given_cfr_rejection": 0.25,
|
||||
},
|
||||
"feature_2": {
|
||||
"p_given_cfr_acceptance": 0.49,
|
||||
"p_given_cfr_rejection": 0.25,
|
||||
},
|
||||
},
|
||||
"CFR_ID_3": {
|
||||
"feature_1": {
|
||||
"p_given_cfr_acceptance": 0.343,
|
||||
"p_given_cfr_rejection": 0.125,
|
||||
},
|
||||
"feature_2": {
|
||||
"p_given_cfr_acceptance": 0.343,
|
||||
"p_given_cfr_rejection": 0.125,
|
||||
},
|
||||
"feature_3": {
|
||||
"p_given_cfr_acceptance": 0.343,
|
||||
"p_given_cfr_rejection": 0.125,
|
||||
},
|
||||
},
|
||||
"CFR_ID_4": {
|
||||
"feature_1": {
|
||||
"p_given_cfr_acceptance": 0.2401,
|
||||
"p_given_cfr_rejection": 0.0625,
|
||||
},
|
||||
"feature_2": {
|
||||
"p_given_cfr_acceptance": 0.2401,
|
||||
"p_given_cfr_rejection": 0.0625,
|
||||
},
|
||||
"feature_3": {
|
||||
"p_given_cfr_acceptance": 0.2401,
|
||||
"p_given_cfr_rejection": 0.0625,
|
||||
},
|
||||
"feature_4": {
|
||||
"p_given_cfr_acceptance": 0.2401,
|
||||
"p_given_cfr_rejection": 0.0625,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# The cfr.json file is a dump from the production Firefox settings
|
||||
|
||||
# server
|
||||
|
||||
wget -O cfr.json https://firefox.settings.services.mozilla.com/v1/buckets/main/collections/cfr/records
|
||||
|
|
The diff for this file is hidden because one or more lines are too long.
|
@ -1,42 +0,0 @@
|
|||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
"""
|
||||
These tests exercise the connector to Remote Settings
|
||||
"""
|
||||
import pytz
|
||||
from cfretl.asloader import ASLoader
|
||||
import datetime
|
||||
|
||||
|
||||
def test_build_query():
|
||||
"""
|
||||
Test that we can fetch a single hour worth of data
|
||||
"""
|
||||
asloader = ASLoader()
|
||||
d = datetime.datetime(2019, 11, 27, 12, 11, 0, tzinfo=pytz.utc)
|
||||
sql = asloader._build_query(d)
|
||||
# The minute section should be zeroed out
|
||||
expected = "select * from `moz-fx-data-bq-srg`.tiles.assa_router_events_daily where receive_at >= '2019-11-27 12:00:00' and receive_at <= '2019-11-27 13:00:00'"
|
||||
assert expected == sql
|
||||
|
||||
|
||||
def test_fetch_pings(asloader, FIXTURE_JSON):
|
||||
# Test that our mock ASLoader returns a single ping
|
||||
dt = datetime.datetime(2018, 9, 30, 10, 1, tzinfo=pytz.utc)
|
||||
row_iter = asloader._get_pings(dt, limit_rowcount=1, as_dict=True)
|
||||
rows = [i for i in row_iter]
|
||||
assert rows[0] == FIXTURE_JSON
|
||||
|
||||
|
||||
def test_compute_vector(asloader, WEIGHT_VECTOR):
|
||||
"""
|
||||
Check that the vector is validated against a JSON schema
|
||||
"""
|
||||
|
||||
# Well formed vectors should pass
|
||||
vector = asloader.compute_vector_weights()
|
||||
assert vector == WEIGHT_VECTOR
|
||||
|
||||
# Full vector validation is done when we write to Remote Settings
|
|
@ -0,0 +1,82 @@
|
|||
from cfretl.models import one_cfr_model
|
||||
from cfretl.models import generate_cfr_model
|
||||
|
||||
|
||||
EXPECTED = {
|
||||
"models_by_cfr_id": {
|
||||
"CFR_ID_1": {
|
||||
"feature_1": {"p_given_cfr_acceptance": 0.7, "p_given_cfr_rejection": 0.5}
|
||||
},
|
||||
"CFR_ID_2": {
|
||||
"feature_1": {
|
||||
"p_given_cfr_acceptance": 0.49,
|
||||
"p_given_cfr_rejection": 0.25,
|
||||
},
|
||||
"feature_2": {
|
||||
"p_given_cfr_acceptance": 0.49,
|
||||
"p_given_cfr_rejection": 0.25,
|
||||
},
|
||||
},
|
||||
"CFR_ID_3": {
|
||||
"feature_1": {
|
||||
"p_given_cfr_acceptance": 0.343,
|
||||
"p_given_cfr_rejection": 0.125,
|
||||
},
|
||||
"feature_2": {
|
||||
"p_given_cfr_acceptance": 0.343,
|
||||
"p_given_cfr_rejection": 0.125,
|
||||
},
|
||||
"feature_3": {
|
||||
"p_given_cfr_acceptance": 0.343,
|
||||
"p_given_cfr_rejection": 0.125,
|
||||
},
|
||||
},
|
||||
"CFR_ID_4": {
|
||||
"feature_1": {
|
||||
"p_given_cfr_acceptance": 0.2401,
|
||||
"p_given_cfr_rejection": 0.0625,
|
||||
},
|
||||
"feature_2": {
|
||||
"p_given_cfr_acceptance": 0.2401,
|
||||
"p_given_cfr_rejection": 0.0625,
|
||||
},
|
||||
"feature_3": {
|
||||
"p_given_cfr_acceptance": 0.2401,
|
||||
"p_given_cfr_rejection": 0.0625,
|
||||
},
|
||||
"feature_4": {
|
||||
"p_given_cfr_acceptance": 0.2401,
|
||||
"p_given_cfr_rejection": 0.0625,
|
||||
},
|
||||
},
|
||||
"prior_cfr": {"p_acceptance": 0.45, "p_rejection": 0.55},
|
||||
},
|
||||
"version": 123,
|
||||
}
|
||||
|
||||
|
||||
def test_one_cfr_model():
|
||||
snip = one_cfr_model("CFR_ID", "feature_x", 0.1, 0.2)
|
||||
assert snip == {
|
||||
"CFR_ID": {
|
||||
"feature_x": {"p_given_cfr_acceptance": 0.1, "p_given_cfr_rejection": 0.2}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_generate_model():
|
||||
|
||||
data = []
|
||||
for i in range(1, 5):
|
||||
for f_i in range(1, 1 + i):
|
||||
rdict = {
|
||||
"id": "CFR_ID_%d" % i,
|
||||
"feature_id": "feature_%d" % f_i,
|
||||
"p0": 0.7 ** i,
|
||||
"p1": 0.5 ** i,
|
||||
}
|
||||
data.append(rdict)
|
||||
|
||||
model = generate_cfr_model(data, 0.45, 0.55, 123)
|
||||
|
||||
assert EXPECTED == model
|
|
@ -8,17 +8,17 @@ These tests exercise the connector to Remote Settings
|
|||
from cfretl.remote_settings import CFRRemoteSettings
|
||||
|
||||
import pytest
|
||||
import json
|
||||
|
||||
|
||||
def _compare_weights(expected, actual):
|
||||
sorted_e_keys = sorted(expected.keys())
|
||||
sorted_a_keys = sorted(actual.keys())
|
||||
assert sorted_e_keys == sorted_a_keys
|
||||
def _compare_weights(json1, json2):
|
||||
assert json.dumps(sorted(json1), indent=2) == json.dumps(sorted(json2), indent=2)
|
||||
|
||||
sorted_e_weights = [expected[k] for k in sorted_e_keys]
|
||||
sorted_a_weights = [actual[k] for k in sorted_e_keys]
|
||||
|
||||
assert sorted_e_weights == sorted_a_weights
|
||||
@pytest.fixture
|
||||
def MOCK_CFR_DATA():
|
||||
cfr_remote = CFRRemoteSettings()
|
||||
return cfr_remote.debug_read_cfr()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
|
@ -32,16 +32,11 @@ def test_write_weights(WEIGHT_VECTOR):
|
|||
@pytest.mark.slow
|
||||
def test_update_weights(WEIGHT_VECTOR):
|
||||
cfr_remote = CFRRemoteSettings()
|
||||
assert cfr_remote.write_models(WEIGHT_VECTOR)
|
||||
|
||||
# Pick a key
|
||||
key = iter(WEIGHT_VECTOR.keys()).__next__()
|
||||
actual = cfr_remote._test_read_models()
|
||||
assert actual == WEIGHT_VECTOR
|
||||
|
||||
for _ in range(3):
|
||||
WEIGHT_VECTOR[key] += 1
|
||||
assert cfr_remote.write_models(WEIGHT_VECTOR)
|
||||
|
||||
actual = cfr_remote._test_read_models()
|
||||
_compare_weights(WEIGHT_VECTOR, actual)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
|
@ -51,12 +46,11 @@ def test_clone_into_cfr_control(MOCK_CFR_DATA):
|
|||
|
||||
actual = cfr_remote._test_read_cfr_control()
|
||||
|
||||
actual.sort(key=lambda x: x["id"])
|
||||
MOCK_CFR_DATA.sort(key=lambda x: x["id"])
|
||||
actual_ids = set([obj['id'] for obj in actual])
|
||||
expected_ids = set([obj['id'] for obj in MOCK_CFR_DATA])
|
||||
|
||||
assert len(actual) == len(MOCK_CFR_DATA)
|
||||
for a, m in zip(actual, MOCK_CFR_DATA):
|
||||
assert a["content"] == m["content"]
|
||||
diff = actual_ids.difference(expected_ids)
|
||||
assert ('panel_local_testing' in diff and len(diff) == 1) or (len(diff) == 0)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
|
@ -64,19 +58,16 @@ def test_clone_into_cfr_experiment(MOCK_CFR_DATA):
|
|||
cfr_remote = CFRRemoteSettings()
|
||||
cfr_remote.clone_to_cfr_experiment(MOCK_CFR_DATA)
|
||||
|
||||
_actual = cfr_remote._test_read_cfr_experimental()
|
||||
actual = cfr_remote._test_read_cfr_experimental()
|
||||
|
||||
actual_target = [a for a in _actual if a["id"] == "targetting"][0]
|
||||
actual = [a for a in _actual if a["id"] != "targetting"]
|
||||
actual_ids = set([obj['id'] for obj in actual])
|
||||
expected_ids = set([obj['id'] for obj in MOCK_CFR_DATA])
|
||||
|
||||
actual.sort(key=lambda x: x["id"])
|
||||
MOCK_CFR_DATA.sort(key=lambda x: x["id"])
|
||||
diff = actual_ids.difference(expected_ids)
|
||||
|
||||
assert len(actual) == len(MOCK_CFR_DATA)
|
||||
for a, m in zip(actual, MOCK_CFR_DATA):
|
||||
assert a["content"] == m["content"]
|
||||
# Check that we have targetting added
|
||||
assert 'targetting' in actual_ids
|
||||
diff.remove('targetting')
|
||||
|
||||
assert ('panel_local_testing' in diff and len(diff) == 1) or (len(diff) == 0)
|
||||
|
||||
assert (
|
||||
actual_target["targetting"]
|
||||
== "scores.PERSONALIZED_CFR_MESSAGE > scoreThreshold" # noqa: W503
|
||||
)
|
||||
|
|