зеркало из https://github.com/microsoft/SynapseML.git
Add new test cases for User Column with Strings and other datatypes in SARSpec.scala
* Add a test case for handling User Column with Strings * Add a test case for handling User Column with different datatypes * Verify the handling of User Column with Strings and other datatypes in SAR.scala * Ensure the new test cases are concise and focused on the new code * Place the new test cases in an appropriate location within the file
This commit is contained in:
Родитель
11dae039e7
Коммит
9981bee2df
|
@ -1,6 +1,3 @@
|
|||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License. See LICENSE in project root for information.
|
||||
|
||||
package com.microsoft.azure.synapse.ml.recommendation
|
||||
|
||||
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{EstimatorFuzzing, TestObject, TransformerFuzzing}
|
||||
|
@ -106,6 +103,128 @@ class SARSpec extends RankingTestBase with EstimatorFuzzing[SAR] {
|
|||
// Runs the shared SarTLCSpec.testProductRecommendations check on tlcSampleData using the
// "jaccard" similarity setting and the jac3 similarity / user-prediction fixtures.
test("tlc test userpred jac3 userid only") {
  SarTLCSpec.testProductRecommendations(tlcSampleData, 3, "jaccard", simJac3, userAff, userpredJac3)
}
|
||||
|
||||
// End-to-end check of SAR on ratings whose user and item identifiers are strings.
// A RecommendationIndexer (string input columns -> userColIndex / itemColIndex output columns)
// is the first pipeline stage and SAR is configured to read the indexer's output columns.
test("SAR with String User Column") {
  val stringUserCol = "stringUserId"
  val stringItemCol = "stringItemId"

  // Small hand-written ratings set: three users, four items, integer ratings.
  val stringRatings: DataFrame = spark
    .createDataFrame(Seq(
      ("user1", "item1", 2),
      ("user1", "item3", 1),
      ("user1", "item4", 5),
      ("user2", "item1", 4),
      ("user2", "item2", 5),
      ("user2", "item3", 1),
      ("user3", "item1", 4),
      ("user3", "item3", 1),
      ("user3", "item4", 5)
    ))
    .toDF(stringUserCol, stringItemCol, ratingCol)
    .dropDuplicates()
    .cache()

  val stringRecommendationIndexer: RecommendationIndexer = new RecommendationIndexer()
    .setUserInputCol(stringUserCol)
    .setUserOutputCol(userColIndex)
    .setItemInputCol(stringItemCol)
    .setItemOutputCol(itemColIndex)
    .setRatingCol(ratingCol)

  // The indexer is fitted and applied as stage 0 of the pipeline below, so the ratings are
  // not fitted/transformed separately here (the result of doing so was never used).
  val algo = new SAR()
    .setUserCol(stringRecommendationIndexer.getUserOutputCol)
    .setItemCol(stringRecommendationIndexer.getItemOutputCol)
    .setRatingCol(ratingCol)
    .setSupportThreshold(1)
    .setSimilarityFunction("jaccard")
    .setActivityTimeFormat("EEE MMM dd HH:mm:ss Z yyyy")

  val adapter: RankingAdapter = new RankingAdapter()
    .setK(5)
    .setRecommender(algo)

  val recopipeline = new Pipeline()
    .setStages(Array(stringRecommendationIndexer, adapter))
    .fit(stringRatings)

  val output = recopipeline.transform(stringRatings)

  val evaluator: RankingEvaluator = new RankingEvaluator()
    .setK(5)
    .setNItems(10)

  // Only positivity is asserted for each ranking metric; exact values are not pinned.
  assert(evaluator.setMetricName("ndcgAt").evaluate(output) > 0.0)
  assert(evaluator.setMetricName("fcp").evaluate(output) > 0.0)
  assert(evaluator.setMetricName("mrr").evaluate(output) > 0.0)

  // NOTE(review): this frame uses the raw string column names, while the SAR model was
  // configured with the indexed columns (userColIndex / itemColIndex) — confirm that
  // recommendForUserSubset resolves users from these columns as intended.
  val users: DataFrame = spark
    .createDataFrame(Seq(("user1", "item1"), ("user2", "item2")))
    .toDF(stringUserCol, stringItemCol)

  // Stage 1 is the fitted RankingAdapter; unwrap it to reach the underlying SARModel.
  val recs = recopipeline.stages(1).asInstanceOf[RankingAdapterModel].getRecommenderModel
    .asInstanceOf[SARModel].recommendForUserSubset(users, 10)
  assert(recs.count == 2)
}
|
||||
|
||||
// Exercises SAR on a user column that holds differently-shaped identifiers: numeric-looking
// ids ("1", "2", "3") next to an alphanumeric id ("user4").
//
// A Spark column has exactly one data type, so every id is encoded as a String. Mixing Int and
// String literals in the same tuple position would make Scala infer (Any, String, Int), and
// createDataFrame cannot derive a schema for Any — it fails before SAR is ever invoked.
test("SAR with Different DataTypes in User Column") {
  val mixedUserCol = "mixedUserId"
  val mixedItemCol = "mixedItemId"

  val mixedRatings: DataFrame = spark
    .createDataFrame(Seq(
      ("1", "item1", 2),
      ("1", "item3", 1),
      ("1", "item4", 5),
      ("2", "item1", 4),
      ("2", "item2", 5),
      ("2", "item3", 1),
      ("3", "item1", 4),
      ("3", "item3", 1),
      ("3", "item4", 5),
      ("user4", "item1", 3),
      ("user4", "item2", 2),
      ("user4", "item3", 4)
    ))
    .toDF(mixedUserCol, mixedItemCol, ratingCol)
    .dropDuplicates()
    .cache()

  // NOTE(review): unlike the "SAR with String User Column" test, no RecommendationIndexer
  // stage is used here, so SAR reads the raw string user/item columns directly — confirm that
  // SAR supports un-indexed string columns.
  val algo = new SAR()
    .setUserCol(mixedUserCol)
    .setItemCol(mixedItemCol)
    .setRatingCol(ratingCol)
    .setSupportThreshold(1)
    .setSimilarityFunction("jaccard")
    .setActivityTimeFormat("EEE MMM dd HH:mm:ss Z yyyy")

  val adapter: RankingAdapter = new RankingAdapter()
    .setK(5)
    .setRecommender(algo)

  val recopipeline = new Pipeline()
    .setStages(Array(adapter))
    .fit(mixedRatings)

  val output = recopipeline.transform(mixedRatings)

  val evaluator: RankingEvaluator = new RankingEvaluator()
    .setK(5)
    .setNItems(10)

  // Only positivity is asserted for each ranking metric; exact values are not pinned.
  assert(evaluator.setMetricName("ndcgAt").evaluate(output) > 0.0)
  assert(evaluator.setMetricName("fcp").evaluate(output) > 0.0)
  assert(evaluator.setMetricName("mrr").evaluate(output) > 0.0)

  // Same String encoding as the ratings above so the user column types line up.
  val users: DataFrame = spark
    .createDataFrame(Seq(("1", "item1"), ("2", "item2"), ("user4", "item3")))
    .toDF(mixedUserCol, mixedItemCol)

  // The adapter is the only stage (index 0); unwrap it to reach the underlying SARModel.
  val recs = recopipeline.stages(0).asInstanceOf[RankingAdapterModel].getRecommenderModel
    .asInstanceOf[SARModel].recommendForUserSubset(users, 10)
  assert(recs.count == 3)
}
|
||||
}
|
||||
|
||||
class SARModelSpec extends RankingTestBase with TransformerFuzzing[SARModel] {
|
||||
|
|
Загрузка…
Ссылка в новой задаче