зеркало из https://github.com/microsoft/SynapseML.git
Add new test cases for User Column with Strings and other datatypes in SARSpec.scala
* Add a test case for handling User Column with Strings
* Add a test case for handling User Column with different datatypes
* Verify the handling of User Column with Strings and other datatypes in SAR.scala
* Ensure the new test cases are concise and focused on the new code
* Place the new test cases in an appropriate location within the file
This commit is contained in:
Родитель
11dae039e7
Коммит
9981bee2df
|
@ -1,6 +1,3 @@
|
||||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT License. See LICENSE in project root for information.
|
|
||||||
|
|
||||||
package com.microsoft.azure.synapse.ml.recommendation
|
package com.microsoft.azure.synapse.ml.recommendation
|
||||||
|
|
||||||
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{EstimatorFuzzing, TestObject, TransformerFuzzing}
|
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{EstimatorFuzzing, TestObject, TransformerFuzzing}
|
||||||
|
@ -106,6 +103,128 @@ class SARSpec extends RankingTestBase with EstimatorFuzzing[SAR] {
|
||||||
// Unchanged context test (each line appears twice in this rendering, presumably once per side
// of the diff view). It delegates entirely to SarTLCSpec.testProductRecommendations, passing
// tlcSampleData, the value 3, the "jaccard" similarity name, and the simJac3, userAff and
// userpredJac3 fixtures.
test("tlc test userpred jac3 userid only")(
|
test("tlc test userpred jac3 userid only")(
|
||||||
SarTLCSpec.testProductRecommendations(tlcSampleData, 3, "jaccard", simJac3, userAff, userpredJac3))
|
SarTLCSpec.testProductRecommendations(tlcSampleData, 3, "jaccard", simJac3, userAff, userpredJac3))
|
||||||
|
|
||||||
|
// Trains SAR on ratings whose user and item ids are plain strings. The ids are mapped to the
// userColIndex / itemColIndex columns by a RecommendationIndexer stage, and SAR is configured
// to read those indexed columns. The test then checks that the ranking metrics are positive
// and that a two-user subset yields two recommendation rows.
test("SAR with String User Column") {
  val stringUserCol = "stringUserId"
  val stringItemCol = "stringItemId"

  val stringRatings: DataFrame = spark
    .createDataFrame(Seq(
      ("user1", "item1", 2),
      ("user1", "item3", 1),
      ("user1", "item4", 5),
      ("user2", "item1", 4),
      ("user2", "item2", 5),
      ("user2", "item3", 1),
      ("user3", "item1", 4),
      ("user3", "item3", 1),
      ("user3", "item4", 5)
    ))
    .toDF(stringUserCol, stringItemCol, ratingCol)
    .dropDuplicates()
    .cache()

  val stringRecommendationIndexer: RecommendationIndexer = new RecommendationIndexer()
    .setUserInputCol(stringUserCol)
    .setUserOutputCol(userColIndex)
    .setItemInputCol(stringItemCol)
    .setItemOutputCol(itemColIndex)
    .setRatingCol(ratingCol)

  // SAR consumes the indexer's output columns, not the raw string ids.
  val algo = new SAR()
    .setUserCol(stringRecommendationIndexer.getUserOutputCol)
    .setItemCol(stringRecommendationIndexer.getItemOutputCol)
    .setRatingCol(ratingCol)
    .setSupportThreshold(1)
    .setSimilarityFunction("jaccard")
    .setActivityTimeFormat("EEE MMM dd HH:mm:ss Z yyyy")

  val adapter: RankingAdapter = new RankingAdapter()
    .setK(5)
    .setRecommender(algo)

  // Stage 0 is the fitted indexer, stage 1 the fitted ranking adapter.
  val recopipeline = new Pipeline()
    .setStages(Array(stringRecommendationIndexer, adapter))
    .fit(stringRatings)

  val output = recopipeline.transform(stringRatings)

  val evaluator: RankingEvaluator = new RankingEvaluator()
    .setK(5)
    .setNItems(10)

  assert(evaluator.setMetricName("ndcgAt").evaluate(output) > 0.0)
  assert(evaluator.setMetricName("fcp").evaluate(output) > 0.0)
  assert(evaluator.setMetricName("mrr").evaluate(output) > 0.0)

  val users: DataFrame = spark
    .createDataFrame(Seq(("user1", "item1"), ("user2", "item2")))
    .toDF(stringUserCol, stringItemCol)

  // The SAR model was configured with the indexed column names, so the raw string subset is
  // passed through the fitted indexer stage before querying it.
  // NOTE(review): assumes the fitted indexer only needs the user/item input columns - confirm.
  val indexedUsers: DataFrame = recopipeline.stages(0).transform(users)

  val recs = recopipeline.stages(1).asInstanceOf[RankingAdapterModel].getRecommenderModel
    .asInstanceOf[SARModel].recommendForUserSubset(indexedUsers, 10)
  assert(recs.count == 2)
}
|
||||||
|
|
||||||
|
// Trains SAR on ratings whose user ids mix numeric-looking ("1", "2", "3") and textual
// ("user4") identifiers. A DataFrame column has a single type, and a Seq of tuples that mixes
// Int and String in one slot widens to Any, for which Spark cannot derive a schema. Every id is
// therefore encoded as a string and mapped to userColIndex / itemColIndex by a
// RecommendationIndexer stage before SAR reads it.
test("SAR with Different DataTypes in User Column") {
  val mixedUserCol = "mixedUserId"
  val mixedItemCol = "mixedItemId"

  val mixedRatings: DataFrame = spark
    .createDataFrame(Seq(
      ("1", "item1", 2),
      ("1", "item3", 1),
      ("1", "item4", 5),
      ("2", "item1", 4),
      ("2", "item2", 5),
      ("2", "item3", 1),
      ("3", "item1", 4),
      ("3", "item3", 1),
      ("3", "item4", 5),
      ("user4", "item1", 3),
      ("user4", "item2", 2),
      ("user4", "item3", 4)
    ))
    .toDF(mixedUserCol, mixedItemCol, ratingCol)
    .dropDuplicates()
    .cache()

  val mixedRecommendationIndexer: RecommendationIndexer = new RecommendationIndexer()
    .setUserInputCol(mixedUserCol)
    .setUserOutputCol(userColIndex)
    .setItemInputCol(mixedItemCol)
    .setItemOutputCol(itemColIndex)
    .setRatingCol(ratingCol)

  // SAR consumes the indexer's output columns, not the raw ids.
  val algo = new SAR()
    .setUserCol(mixedRecommendationIndexer.getUserOutputCol)
    .setItemCol(mixedRecommendationIndexer.getItemOutputCol)
    .setRatingCol(ratingCol)
    .setSupportThreshold(1)
    .setSimilarityFunction("jaccard")
    .setActivityTimeFormat("EEE MMM dd HH:mm:ss Z yyyy")

  val adapter: RankingAdapter = new RankingAdapter()
    .setK(5)
    .setRecommender(algo)

  // Stage 0 is the fitted indexer, stage 1 the fitted ranking adapter.
  val recopipeline = new Pipeline()
    .setStages(Array(mixedRecommendationIndexer, adapter))
    .fit(mixedRatings)

  val output = recopipeline.transform(mixedRatings)

  val evaluator: RankingEvaluator = new RankingEvaluator()
    .setK(5)
    .setNItems(10)

  assert(evaluator.setMetricName("ndcgAt").evaluate(output) > 0.0)
  assert(evaluator.setMetricName("fcp").evaluate(output) > 0.0)
  assert(evaluator.setMetricName("mrr").evaluate(output) > 0.0)

  val users: DataFrame = spark
    .createDataFrame(Seq(("1", "item1"), ("2", "item2"), ("user4", "item3")))
    .toDF(mixedUserCol, mixedItemCol)

  // The SAR model was configured with the indexed column names, so the raw subset is passed
  // through the fitted indexer stage before querying it.
  // NOTE(review): assumes the fitted indexer only needs the user/item input columns - confirm.
  val indexedUsers: DataFrame = recopipeline.stages(0).transform(users)

  val recs = recopipeline.stages(1).asInstanceOf[RankingAdapterModel].getRecommenderModel
    .asInstanceOf[SARModel].recommendForUserSubset(indexedUsers, 10)
  assert(recs.count == 3)
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class SARModelSpec extends RankingTestBase with TransformerFuzzing[SARModel] {
|
class SARModelSpec extends RankingTestBase with TransformerFuzzing[SARModel] {
|
||||||
|
|
Загрузка…
Ссылка в новой задаче