import argparse
import os
from itertools import chain

import numpy as np
import pandas as pd
import pyspark.sql.functions as F
from google.cloud import bigquery
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType

from lifetimes import BetaGeoFitter

PRED_METRICS = [
    "days_seen",
    "days_searched",
    "days_tagged_searched",
    "days_clicked_ads",
    "days_searched_with_ads",
]
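
# Each metric in PRED_METRICS is stored in the source table as a struct of
# `frequency`, `recency`, and `T` fields (the RFM-style inputs expected by
# lifetimes); a separate BG/NBD model is fit for each metric below.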


def train_metric(d, metric, plot=True, penalty=0):
    """Fit a BG/NBD model for one engagement metric.

    Returns a table of actual vs. simulated frequency counts (a rough
    goodness-of-fit check) together with the fitted model.
    """
    frequency = metric + "_frequency"
    recency = metric + "_recency"
    T = metric + "_T"

    # Keep clients with at least one active day and a non-negative recency;
    # copy so the caller's dataframe is not modified in place.
    train = d[(d[frequency] > 0) & (d[recency] >= 0)].copy()
    # BetaGeoFitter expects repeat frequency, i.e. total observations minus one.
    train[frequency] = train[frequency] - 1

    bgf = BetaGeoFitter(penalizer_coef=penalty)
    bgf.fit(train[frequency], train[recency], train[T])
    n = bgf.data.shape[0]
    simulated_data = bgf.generate_new_data(size=n)

    # Compare actual vs. simulated frequency distributions over the first 28
    # buckets.
    model_counts = pd.DataFrame(
        bgf.data["frequency"].value_counts().sort_index().iloc[:28]
    )
    simulated_counts = pd.DataFrame(
        simulated_data["frequency"].value_counts().sort_index().iloc[:28]
    )
    combined_counts = model_counts.merge(
        simulated_counts, how="outer", left_index=True, right_index=True
    ).fillna(0)
    combined_counts.columns = ["Actual", "Model"]
    if plot:
        # `display` is provided by the notebook environment (e.g. Databricks).
        combined_counts.plot.bar()
        display()
    return combined_counts, bgf


def catch_none(x):
    """Treat missing RFM values as zero."""
    if x is None:
        return 0
    return x


def ltv_predict(t, frequency, recency, T, model):
    """Predict the expected number of active days over the next `t` days."""
    pred = model.conditional_expected_number_of_purchases_up_to_time(
        t, catch_none(frequency), catch_none(recency), catch_none(T)
    )

    # The model can overshoot, so cap predictions at the horizon length.
    if pred > t:
        return float(t)
    return float(pred)
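

# For reference, the underlying lifetimes call made by ltv_predict looks like
# the following (the numbers are purely illustrative):
#
#   model.conditional_expected_number_of_purchases_up_to_time(28, 5, 20.0, 30.0)
#
# i.e. the expected number of active days over the next 28 days for a client
# with 5 repeat active days, last active on day 20 of a 30-day observation
# window.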


def main(
    submission_date,
    project_id,
    dataset_id,
    source_qualified_table_id,
    intermediate_table_id,
    model_input_table_id,
    model_output_table_id,
    temporary_gcs_bucket,
    training_sample,
    prediction_days,
):
    """Model the lifetime value (LTV) of clients based on search activity.

    This reads a single partition from a source table into an intermediate
    table in an analysis dataset. The table is transformed for modeling, and
    the model inputs and predictions are then stored in separate tables in the
    same dataset.
    """
    print(f"Running ltv_daily job for {submission_date}")

    bq = bigquery.Client()
    table_ref = bq.dataset(dataset_id, project=project_id).table(intermediate_table_id)

    # Define the query job configuration: bind the submission-date parameter
    # and set the intermediate table as the destination for the materialized
    # partition.
    job_config = bigquery.QueryJobConfig()
    job_config.query_parameters = [
        bigquery.ScalarQueryParameter("submission_date", "STRING", submission_date)
    ]
    job_config.destination = table_ref
    job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE

    query = f"""
    SELECT
      *
    FROM
      `{source_qualified_table_id}`
    WHERE
      submission_date = @submission_date
    """
    query_job = bq.query(query, job_config=job_config)
    query_job.result()

    # Read the materialized partition back through the Spark BigQuery connector.
    spark = SparkSession.builder.getOrCreate()
    search_rfm_full = (
        spark.read.format("bigquery")
        .option("table", f"{project_id}.{dataset_id}.{intermediate_table_id}")
        .load()
    )

    columns = [
        [
            F.col(metric + ".frequency").alias(metric + "_frequency"),
            F.col(metric + ".recency").alias(metric + "_recency"),
            F.col(metric + ".T").alias(metric + "_T"),
        ]
        for metric in PRED_METRICS
    ]

    # flatten list
    columns = [item for sublist in columns for item in sublist]

    prediction_prefix = "prediction_"
    p_alive_prefix = "p_alive_"

    model_perf_data = pd.DataFrame()
    model_pred_data = None
    # Train on a sample of clients; predictions below are made on the full data.
    search_rfm_ds = search_rfm_full.limit(training_sample).select(columns).toPandas()
    for metric in PRED_METRICS:
        # train and extract model performance
        model_perf, model = train_metric(search_rfm_ds, metric, plot=False, penalty=0.8)
        model_perf["pct"] = model_perf.Model / (model_perf.Actual + 1) - 1
        model_perf["metric"] = metric
        model_perf["date"] = submission_date
        model_perf_data = pd.concat([model_perf_data, model_perf])

        # Make predictions with the fitted model. Each UDF receives the
        # per-metric RFM struct (frequency/recency/T); `model=model` binds the
        # current metric's model at definition time, and `lifetimes` is
        # imported inside the UDFs so the module is available on the executors.
        @F.udf(DoubleType())
        def ltv_predict_metric(rfm, model=model):
            import lifetimes

            return ltv_predict(
                prediction_days, rfm.frequency, rfm.recency, rfm.T, model
            )

        @F.udf(DoubleType())
        def ltv_prob_alive(rfm, model=model):
            import lifetimes

            p_alive = float(
                model.conditional_probability_alive(
                    catch_none(rfm.frequency),
                    catch_none(rfm.recency),
                    catch_none(rfm.T),
                )
            )

            # Lifetimes returns 1.0 if frequency == 0
            # https://github.com/CamDavidsonPilon/lifetimes/blob/master/lifetimes/fitters/beta_geo_fitter.py#L293
            if p_alive >= 1.0:
                return 0.0
            return p_alive

        # Go back to the full sample here and score every client.
        predictions = search_rfm_full.select(
            "*",
            ltv_predict_metric(metric).alias(prediction_prefix + metric),
            ltv_prob_alive(metric).alias(p_alive_prefix + metric),
        )

        if model_pred_data is None:
            model_pred_data = predictions
        else:
            model_pred_data = model_pred_data.join(
                predictions.select(
                    "client_id", prediction_prefix + metric, p_alive_prefix + metric
                ),
                on="client_id",
            )
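
    # Collapse the per-metric prediction_* / p_alive_* columns into two map
    # columns keyed by metric name, so the output schema stays stable as
    # metrics are added or removed.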
    predictions = F.create_map(
        list(
            chain(
                *((F.lit(name), F.col(prediction_prefix + name)) for name in PRED_METRICS)
            )
        )
    ).alias("predictions")

    p_alive = F.create_map(
        list(
            chain(
                *((F.lit(name), F.col(p_alive_prefix + name)) for name in PRED_METRICS)
            )
        )
    ).alias("p_alive")

    # Surface the frequency-bucket index as a column before converting the
    # model performance table to Spark.
    model_perf_data["active_days"] = model_perf_data.index
    model_perf_data_sdf = spark.createDataFrame(model_perf_data).withColumn(
        "date", F.to_date("date")
    )

    # Write to the date partition for this run (BigQuery `table$YYYYMMDD`
    # partition decorator).
    ds_nodash = submission_date.replace("-", "")
    (
        model_perf_data_sdf.write.format("bigquery")
        .option("table", f"{project_id}.{dataset_id}.{model_input_table_id}${ds_nodash}")
        .option("temporaryGcsBucket", temporary_gcs_bucket)
        .option("partitionField", "date")
        .mode("overwrite")
        .save()
    )

    (
        model_pred_data
        # Add prediction columns as maps
        .withColumn("predictions", predictions)
        .withColumn("p_alive", p_alive)
        # Drop top-level prediction columns
        .drop(
            *list(
                chain(*[[prediction_prefix + n, p_alive_prefix + n] for n in PRED_METRICS])
            )
        )
        # Overwrite BQ partition
        .write.format("bigquery")
        .option("table", f"{project_id}.{dataset_id}.{model_output_table_id}${ds_nodash}")
        .option("temporaryGcsBucket", temporary_gcs_bucket)
        .option("partitionField", "submission_date")
        .option("clusteredFields", "sample_id,client_id")
        .option("allowFieldAddition", "true")
        .mode("overwrite")
        .save()
    )
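

# Example invocation (script filename assumed from the job name, the date is
# illustrative, and the other values shown are the argparse defaults below):
#
#   python ltv_daily.py \
#     --submission-date 2020-06-01 \
#     --project-id moz-fx-data-bq-data-science \
#     --dataset-id bmiroglio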
if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("--submission-date", help="date in YYYY-MM-DD")
    parser.add_argument("--training-sample", type=int, default=500_000)
    parser.add_argument("--prediction-days", type=int, default=28)
    parser.add_argument("--project-id", default="moz-fx-data-bq-data-science")
    parser.add_argument(
        "--source-qualified-table-id",
        default="moz-fx-data-shared-prod.search.search_rfm",
    )
    parser.add_argument("--dataset-id", default="bmiroglio")
    parser.add_argument("--intermediate-table-id", default="search_rfm_day")
    parser.add_argument("--model-input-table-id", default="ltv_daily_model_perf_script")
    parser.add_argument("--model-output-table-id", default="ltv_daily_script")
    parser.add_argument(
        "--temporary-gcs-bucket", default="moz-fx-data-bq-data-science-bmiroglio"
    )
    args = parser.parse_args()

    main(**vars(args))