Patch to add a modelBranch argument to the Federated Learning v2 study (#231)
* Initial patch to add a modelBranch argument
The Federated Learning v2 experiment needs an extra parameter passed in to
distinguish between the four different models that need to be trained.
The filter that selects the relevant pings has been modified relative to 76cd9a2a23.
An extra test case was added to demonstrate that the aggregation step filters
pings by variation name.
* Added a filter to drop frecency pings whose study variation includes the substring 'not-submitting'
* Simplified the modelBranch argument to be the full string name
Parent: 76cd9a2a23 · Commit: ed5fdc54d5
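For orientation before the diff: a minimal, standalone sketch of the ping-selection rule described above. The helper name and the example variation strings are made up for illustration; only the startsWith check, the 'not-submitting' exclusion, and the suggestions-displayed check come from the patch itself.

// Sketch of the new selection rule: keep a ping only when its study_variation starts
// with the requested model branch, does not carry the "not-submitting" marker, and
// has a non-negative bookmark_and_history_num_suggestions_displayed value.
def keepsPing(studyVariation: String, modelBranch: String, numSuggestionsDisplayed: Long): Boolean =
  (studyVariation startsWith modelBranch) &&
    !(studyVariation contains "not-submitting") &&
    numSuggestionsDisplayed > -1

// Hypothetical examples:
//   keepsPing("model1", "model1", 3)                 == true
//   keepsPing("model1-not-submitting", "model1", 3)  == false  // dropped by the new filter
//   keepsPing("model2", "model1", 3)                 == false  // belongs to a different branch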
@@ -40,21 +40,23 @@ object FederatedLearningSearchOptimizer extends StreamingJobBase {
       .select("value")
 
     val query = optimize(pings,
-      opts.checkpointPath(), opts.modelOutputBucket(), opts.modelOutputKey(), opts.stateCheckpointPath(),
-      opts.stateBootstrapFilePath.get, Clock.systemUTC(), opts.windowOffsetMinutes(), opts.raiseOnError(),
-      opts.s3EndpointOverride.get)
+      opts.checkpointPath(), opts.modelOutputBucket(),
+      opts.modelOutputKey(), opts.modelBranch(),
+      opts.stateCheckpointPath(), opts.stateBootstrapFilePath.get,
+      Clock.systemUTC(), opts.windowOffsetMinutes(),
+      opts.raiseOnError(), opts.s3EndpointOverride.get)
 
     query.awaitTermination()
   }
 
   def optimize(pings: DataFrame, checkpointPath: String, modelOutputBucket: String, modelOutputKey: String, // scalastyle:ignore
-               stateCheckpointPath: String, stateBootstrapFilePath: Option[String] = None,
+               modelBranch: String, stateCheckpointPath: String, stateBootstrapFilePath: Option[String] = None,
                clock: Clock, windowOffsetMin: Int, raiseOnError: Boolean = false, s3EndpointOverride: Option[String] = None): StreamingQuery = {
-    val aggregates = aggregate(pings, clock, windowOffsetMin, raiseOnError)
+    val aggregates = aggregate(pings, modelBranch, clock, windowOffsetMin, raiseOnError)
     writeUpdates(aggregates, checkpointPath, modelOutputBucket, modelOutputKey, stateCheckpointPath, stateBootstrapFilePath, s3EndpointOverride)
   }
 
-  def aggregate(pings: DataFrame, clock: Clock, windowOffsetMin: Int, raiseOnError: Boolean = false): Dataset[FrecencyUpdateAggregate] = {
+  def aggregate(pings: DataFrame, modelBranch: String, clock: Clock, windowOffsetMin: Int, raiseOnError: Boolean = false): Dataset[FrecencyUpdateAggregate] = {
     import pings.sparkSession.implicits._
 
     val frecencyUpdates: Dataset[FrecencyUpdate] = pings.flatMap { v =>
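To make the new parameter order explicit, here is a hedged sketch of calling the reworked optimize signature directly with named arguments. The bucket, key, and path literals are placeholders, and pings stands for the raw ping DataFrame the job builds before this call.

// Illustrative call against the new optimize signature; all literal values are placeholders.
import java.time.Clock

val query = FederatedLearningSearchOptimizer.optimize(
  pings,
  checkpointPath = "s3://example-bucket/checkpoints/spark",
  modelOutputBucket = "example-output-bucket",
  modelOutputKey = "models/frecency",
  modelBranch = "model1",                    // new argument added by this patch
  stateCheckpointPath = "s3://example-bucket/checkpoints/state",
  stateBootstrapFilePath = None,
  clock = Clock.systemUTC(),
  windowOffsetMin = 28,
  raiseOnError = false,
  s3EndpointOverride = None
)
query.awaitTermination()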
@@ -64,7 +66,13 @@ object FederatedLearningSearchOptimizer extends StreamingJobBase {
         val docType = fields.getOrElse("docType", "").asInstanceOf[String]
         if ("frecency-update" == docType) {
           val ping = FrecencyUpdatePing(m)
-          if ((ping.payload.study_variation contains "training") && (ping.payload.bookmark_and_history_num_suggestions_displayed > -1)) {
+          if (
+            (
+              (ping.payload.study_variation startsWith modelBranch) &&
+              !(ping.payload.study_variation contains "not-submitting")
+            ) &&
+            (ping.payload.bookmark_and_history_num_suggestions_displayed > -1)
+          ) {
             Option(FrecencyUpdate(
               new Timestamp(clock.millis()),
               ping.payload.model_version,
@ -98,7 +106,9 @@ object FederatedLearningSearchOptimizer extends StreamingJobBase {
|
|||
}
|
||||
|
||||
def writeUpdates(aggregates: Dataset[FrecencyUpdateAggregate], checkpointPath: String, modelOutputBucket: String, modelOutputKey: String,
|
||||
stateCheckpointPath: String, stateBootstrapFilePath: Option[String], s3EndpointOverride: Option[String] = None): StreamingQuery = {
|
||||
stateCheckpointPath: String,
|
||||
stateBootstrapFilePath: Option[String], s3EndpointOverride:
|
||||
Option[String] = None): StreamingQuery = {
|
||||
val writer = aggregates.writeStream
|
||||
.format("com.mozilla.telemetry.learning.federated.FederatedLearningSearchOptimizerS3SinkProvider")
|
||||
.option("checkpointLocation", checkpointPath)
|
||||
|
@@ -129,6 +139,10 @@ object FederatedLearningSearchOptimizer extends StreamingJobBase {
     name = "modelOutputKey",
     descr = "S3 key to save public model iterations",
     required = true)
+  val modelBranch: ScallopOption[String] = opt[String](
+    name = "modelBranch",
+    descr = "Experiment model branch that we are going to be updating",
+    required = true)
   val stateCheckpointPath: ScallopOption[String] = opt[String](
     name = "stateCheckpointPath",
     descr = "Location to save model optimizer state",
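The option added above is declared with name = "modelBranch", so it should surface as a --modelBranch flag. Below is a small, self-contained sketch of the option in isolation; the config class here is invented just to show the expected parsing, since the job's real options class carries many more flags.

// Standalone illustration of the new option; assumes Scallop is on the classpath.
import org.rogach.scallop.{ScallopConf, ScallopOption}

class BranchOnlyOpts(args: Seq[String]) extends ScallopConf(args) {
  val modelBranch: ScallopOption[String] = opt[String](
    name = "modelBranch",
    descr = "Experiment model branch that we are going to be updating",
    required = true)
  verify()
}

// Usage sketch: new BranchOnlyOpts(Seq("--modelBranch", "model1")).modelBranch() returns "model1".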
@@ -512,7 +512,8 @@ object TestUtils {
 
   def generateFrecencyUpdateMessages(size: Int,
                                      fieldsOverride: Option[Map[String, Any]] = None,
-                                     timestamp: Option[Long] = None): Seq[Message] = {
+                                     timestamp: Option[Long] = None,
+                                     modelBranch: String = "model1"): Seq[Message] = {
     val defaultMap = Map(
       "clientId" -> "client1",
       "docType" -> "frecency-update",
@@ -573,7 +574,7 @@ object TestUtils {
       |  "selected_style": "autofill heuristic",
       |  "selected_url_was_same_as_search_string": 0,
       |  "enter_was_pressed": 1,
-      |  "study_variation": "training",
+      |  "study_variation": "${modelBranch}",
       |  "study_addon_version": "2.1.1"
       """.stripMargin
     1.to(size) map { index =>
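Usage note on the extended test helper: the new modelBranch parameter defaults to "model1", so existing call sites keep generating pings for that branch, while the new test overrides it. A short sketch:

// Existing tests keep their behaviour via the "model1" default...
val sameBranchMessages  = TestUtils.generateFrecencyUpdateMessages(10)
// ...while the new test generates pings for a branch the aggregator should ignore.
val otherBranchMessages = TestUtils.generateFrecencyUpdateMessages(10, modelBranch = "model2")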
@@ -35,8 +35,38 @@ class FederatedLearningSearchOptimizerTest extends FlatSpec with Matchers with G
     val pingsStream = MemoryStream[Array[Byte]]
 
     When("they're aggregated")
-    val query = FederatedLearningSearchOptimizer.aggregate(pingsStream.toDF(), clock, 28)
+    val query = FederatedLearningSearchOptimizer.aggregate(pingsStream.toDF(), "model1", clock, 28)
       .writeStream.format("memory").queryName("updates").start()
     pingsStream.addData(pings)
     query.processAllAvailable()
 
+    clock.advance(TimeUnit.MINUTES.toNanos(45))
+
+    pingsStream.addData(TestUtils.generateFrecencyUpdateMessages(5,
+      timestamp = Some(TestUtils.testTimestampNano + TimeUnit.MINUTES.toNanos(45))).map(_.toByteArray).seq)
+    query.processAllAvailable()
+    pingsStream.addData(Array[Byte]())
+    query.processAllAvailable()
+    query.stop()
+
+    Then("a set of aggregates is produced")
+    val res = spark.sql("select * from updates").as[FrecencyUpdateAggregate]
+
+    res.show(false)
+    res.count shouldBe 1
+  }
+
+  "Federated learning Optimizer" should "ignore pings for a different model" in {
+    import spark.implicits._
+
+    Given("set of frecency update pings")
+    val messages = TestUtils.generateFrecencyUpdateMessages(10, modelBranch = "model2")
+    val pings = messages.map(_.toByteArray)
+    val pingsStream = MemoryStream[Array[Byte]]
+
+    When("they're aggregated")
+    val query = FederatedLearningSearchOptimizer.aggregate(pingsStream.toDF(), "model1", clock, 28)
+      .writeStream.format("memory").queryName("updates_ignored").start()
+    pingsStream.addData(pings)
+    query.processAllAvailable()
+
@@ -49,13 +79,14 @@ class FederatedLearningSearchOptimizerTest extends FlatSpec with Matchers with G
     query.processAllAvailable()
     query.stop()
 
-    Then("a set of aggregates is produced")
-    val res = spark.sql("select * from updates").as[FrecencyUpdateAggregate]
+    Then("a empty set of aggregates is produced")
+    val res = spark.sql("select * from updates_ignored").as[FrecencyUpdateAggregate]
 
     res.show(false)
-    res.count shouldBe 1
+    res.count shouldBe 0
   }
 
 
   it should "optimize weight updates and save model" in {
     import spark.implicits._
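A note on the sink rename in the hunk above, with a hedged sketch of the assertion it enables (this mirrors the test's own lines): Spark's memory sink registers an in-memory table under the query name, so the new test reads its own "updates_ignored" table rather than the "updates" table left behind by the first test.

// Reading the new test's dedicated memory-sink table: every generated ping belongs to
// "model2" while the aggregation ran for "model1", so no aggregates should be produced.
val ignored = spark.sql("select * from updates_ignored").as[FrecencyUpdateAggregate]
ignored.count shouldBe 0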
@@ -68,7 +99,7 @@ class FederatedLearningSearchOptimizerTest extends FlatSpec with Matchers with G
 
     val s3 = S3TestUtil(MockEndpointPort, Some(OutputBucket))
     val query = FederatedLearningSearchOptimizer.optimize(pingsStream.toDF(), CheckpointPath + "/spark", OutputBucket,
-      OutputKey, CheckpointPath, None, clock, 28, s3EndpointOverride = Some(MockEndpoint))
+      OutputKey, "model1", CheckpointPath, None, clock, 28, s3EndpointOverride = Some(MockEndpoint))
     pingsStream.addData(pings)
     query.processAllAvailable()
 