Mirror of https://github.com/microsoft/spark.git
Java examples, tests for KMeans and ALS

- Changes ALS to accept RDD[Rating] instead of (Int, Int, Double), making it easier to call from Java.
- Renames class methods from `train` to `run` so that the static methods can be called from Java.
- Adds unit tests which check that both the static and class methods can be called.
- Adds examples which port the main() function of ALS and KMeans to the examples project.

A couple of minor changes to existing code:

- Add a toJavaRDD method in RDD to convert a Scala RDD to a Java RDD easily.
- Work around a bug where using double[] from Java leads to a class cast exception in KMeans initialization.
Parent: d2b0f0c23d
Commit: 471fbadd0c
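For orientation, here is a minimal sketch of the calling pattern this commit enables from Java. The class name, hard-coded ratings, and local master below are illustrative only and not part of the commit; the JavaALS example and JavaALSSuite test added further down are the authoritative versions. Because the instance method is now named `run`, the overloaded static `train` helpers on the ALS object remain callable from Java, and the new RDD.toJavaRDD() bridges the returned Scala RDDs back to the Java API.

import java.util.Arrays;

import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.mllib.recommendation.ALS;
import spark.mllib.recommendation.MatrixFactorizationModel;
import spark.mllib.recommendation.Rating;

// Hypothetical quick-start class, for illustration only.
public class JavaALSQuickStart {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JavaALSQuickStart");

    // The now-public Rating class replaces the old (Int, Int, Double) triples.
    JavaRDD<Rating> ratings = sc.parallelize(Arrays.asList(
        new Rating(0, 0, 5.0), new Rating(0, 1, 1.0),
        new Rating(1, 0, 1.0), new Rating(1, 1, 5.0)));

    // Static helper on the ALS companion object, callable directly from Java.
    MatrixFactorizationModel m1 = ALS.train(ratings.rdd(), 2, 10, 0.01, -1);

    // Instance configuration followed by the renamed `run` method.
    MatrixFactorizationModel m2 = new ALS().setRank(2).setIterations(10).run(ratings.rdd());

    // The new RDD.toJavaRDD() exposes the model's Scala RDDs to the Java API.
    System.out.println("User feature vectors: " + m1.userFeatures().toJavaRDD().count());
    System.out.println("User feature vectors: " + m2.userFeatures().toJavaRDD().count());

    sc.stop();
  }
}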
@@ -31,6 +31,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
 import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+import spark.api.java.JavaRDD
 import spark.broadcast.Broadcast
 import spark.Partitioner._
 import spark.partial.BoundedDouble

@@ -950,4 +951,8 @@ abstract class RDD[T: ClassManifest](
     id,
     origin)
+
+  def toJavaRDD() : JavaRDD[T] = {
+    new JavaRDD(this)(elementClassManifest)
+  }
 }
@@ -118,6 +118,12 @@
       <version>${project.version}</version>
       <classifier>hadoop1</classifier>
     </dependency>
+    <dependency>
+      <groupId>org.spark-project</groupId>
+      <artifactId>spark-mllib</artifactId>
+      <version>${project.version}</version>
+      <classifier>hadoop1</classifier>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-core</artifactId>

@@ -156,6 +162,12 @@
       <version>${project.version}</version>
       <classifier>hadoop2</classifier>
     </dependency>
+    <dependency>
+      <groupId>org.spark-project</groupId>
+      <artifactId>spark-mllib</artifactId>
+      <version>${project.version}</version>
+      <classifier>hadoop2</classifier>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-core</artifactId>
@@ -0,0 +1,87 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package spark.mllib.examples;

import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.Function;

import spark.mllib.recommendation.ALS;
import spark.mllib.recommendation.MatrixFactorizationModel;
import spark.mllib.recommendation.Rating;

import java.io.Serializable;
import java.util.Arrays;
import java.util.StringTokenizer;

import scala.Tuple2;

/**
 * Example using MLLib ALS from Java.
 */
public class JavaALS {

  static class ParseRating extends Function<String, Rating> {
    public Rating call(String line) {
      StringTokenizer tok = new StringTokenizer(line, ",");
      Integer x = Integer.parseInt(tok.nextToken());
      Integer y = Integer.parseInt(tok.nextToken());
      Double rating = Double.parseDouble(tok.nextToken());
      return new Rating(x, y, rating);
    }
  }

  static class FeaturesToString extends Function<Tuple2<Object, double[]>, String> {
    public String call(Tuple2<Object, double[]> element) {
      return element._1().toString() + "," + Arrays.toString(element._2());
    }
  }

  public static void main(String[] args) {
    if (args.length != 5 && args.length != 6) {
      System.err.println(
          "Usage: JavaALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]");
      System.exit(1);
    }

    int rank = Integer.parseInt(args[2]);
    int iterations = Integer.parseInt(args[3]);
    String outputDir = args[4];
    int blocks = -1;
    if (args.length == 6) {
      blocks = Integer.parseInt(args[5]);
    }

    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS",
        System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
    JavaRDD<String> lines = sc.textFile(args[1]);

    JavaRDD<Rating> ratings = lines.map(new ParseRating());

    MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks);

    model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
        outputDir + "/userFeatures");
    model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
        outputDir + "/productFeatures");
    System.out.println("Final user/product features written to " + outputDir);

    System.exit(0);
  }
}
@@ -0,0 +1,81 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package spark.mllib.examples;

import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.Function;

import spark.mllib.clustering.KMeans;
import spark.mllib.clustering.KMeansModel;

import java.util.Arrays;
import java.util.StringTokenizer;

/**
 * Example using MLLib KMeans from Java.
 */
public class JavaKMeans {

  static class ParsePoint extends Function<String, double[]> {
    public double[] call(String line) {
      StringTokenizer tok = new StringTokenizer(line, " ");
      int numTokens = tok.countTokens();
      double[] point = new double[numTokens];
      for (int i = 0; i < numTokens; ++i) {
        point[i] = Double.parseDouble(tok.nextToken());
      }
      return point;
    }
  }

  public static void main(String[] args) {
    if (args.length < 4) {
      System.err.println(
          "Usage: JavaKMeans <master> <input_file> <k> <max_iterations> [<runs>]");
      System.exit(1);
    }

    String inputFile = args[1];
    int k = Integer.parseInt(args[2]);
    int iterations = Integer.parseInt(args[3]);
    int runs = 1;

    if (args.length >= 5) {
      runs = Integer.parseInt(args[4]);
    }

    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
        System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
    JavaRDD<String> lines = sc.textFile(inputFile);

    JavaRDD<double[]> points = lines.map(new ParsePoint());

    KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs);

    System.out.println("Cluster centers:");
    for (double[] center : model.clusterCenters()) {
      System.out.println(" " + Arrays.toString(center));
    }
    double cost = model.computeCost(points.rdd());
    System.out.println("Cost: " + cost);

    System.exit(0);
  }
}
@@ -52,6 +52,11 @@
       <artifactId>scalacheck_${scala.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.novocode</groupId>
+      <artifactId>junit-interface</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
@@ -112,7 +112,7 @@ class KMeans private (
   * Train a K-means model on the given set of points; `data` should be cached for high
   * performance, because this is an iterative algorithm.
   */
-  def train(data: RDD[Array[Double]]): KMeansModel = {
+  def run(data: RDD[Array[Double]]): KMeansModel = {
     // TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable

     val sc = data.sparkContext

@@ -210,7 +210,7 @@ class KMeans private (
   private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
     // Initialize each run's center to a random point
     val seed = new Random().nextInt()
-    val sample = data.takeSample(true, runs, seed)
+    val sample = data.takeSample(true, runs, seed).toSeq
     val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r)))

     // On each step, sample 2 * k points on average for each run with probability proportional

@@ -271,7 +271,7 @@ object KMeans {
       .setMaxIterations(maxIterations)
       .setRuns(runs)
       .setInitializationMode(initializationMode)
-      .train(data)
+      .run(data)
   }

   def train(data: RDD[Array[Double]], k: Int, maxIterations: Int, runs: Int): KMeansModel = {
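The `.toSeq` added in initKMeansParallel above appears to be the workaround mentioned in the commit message for the class cast exception hit when double[] points come in from Java. A minimal sketch of that calling path follows; the class name, points, and local master are illustrative only, while JavaKMeans above and JavaKMeansSuite below exercise the same API.

import java.util.Arrays;

import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.mllib.clustering.KMeans;
import spark.mllib.clustering.KMeansModel;

// Hypothetical quick-start class, for illustration only.
public class JavaKMeansQuickStart {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JavaKMeansQuickStart");

    // double[] points created on the Java side, as in JavaKMeans and JavaKMeansSuite.
    JavaRDD<double[]> points = sc.parallelize(Arrays.asList(
        new double[]{1.0, 2.0, 6.0},
        new double[]{1.0, 3.0, 0.0},
        new double[]{1.0, 4.0, 6.0}), 2);

    // Static helper, usable from Java now that the instance method is named `run`.
    KMeansModel model = KMeans.train(points.rdd(), 1, 5);

    // Equivalent instance-style call, as in JavaKMeansSuite.
    KMeansModel model2 = new KMeans().setK(1).setMaxIterations(5).run(points.rdd());

    System.out.println("Center: " + Arrays.toString(model.clusterCenters()[0]));
    System.out.println("Center: " + Arrays.toString(model2.clusterCenters()[0]));
    sc.stop();
  }
}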
@@ -17,6 +17,9 @@

 package spark.mllib.recommendation

+import java.lang.{Integer => JInt}
+import java.lang.{Double => JDouble}
+
 import scala.collection.mutable.{ArrayBuffer, BitSet}
 import scala.util.Random
 import scala.util.Sorting

@@ -55,8 +58,13 @@ private[recommendation] case class InLinkBlock(
 /**
  * A more compact class to represent a rating than Tuple3[Int, Int, Double].
  */
-private[recommendation] case class Rating(user: Int, product: Int, rating: Double)
+case class Rating(val user: Int, val product: Int, val rating: Double) {
+
+  // Constructor to build a rating from java Integers and Doubles.
+  def this(user: JInt, product: JInt, rating: JDouble) = {
+    this(user.intValue(), product.intValue(), rating.doubleValue())
+  }
+}

 /**
  * Alternating Least Squares matrix factorization.
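A small illustration (hypothetical snippet, not part of the commit) of why the boxed-argument constructor matters from Java: values parsed into Integer/Double locals, as in JavaALS.ParseRating above, resolve to the new JInt/JDouble constructor, while plain primitives use the case-class primary constructor.

import spark.mllib.recommendation.Rating;

// Hypothetical helper class, for illustration only.
public class RatingConstruction {
  public static void main(String[] args) {
    // Primary case-class constructor: primitive int / int / double.
    Rating fromPrimitives = new Rating(1, 2, 5.0);

    // Auxiliary constructor added above: boxed java.lang.Integer / Double values,
    // e.g. the Integer/Double locals produced while parsing text input.
    Integer user = Integer.parseInt("1");
    Integer product = Integer.parseInt("2");
    Double score = Double.parseDouble("5.0");
    Rating fromBoxed = new Rating(user, product, score);

    System.out.println(fromPrimitives + " " + fromBoxed);
  }
}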
@@ -107,7 +115,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
   * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
   * Returns a MatrixFactorizationModel with feature vectors for each user and product.
   */
-  def train(ratings: RDD[(Int, Int, Double)]): MatrixFactorizationModel = {
+  def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
     val numBlocks = if (this.numBlocks == -1) {
       math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2)
     } else {

@@ -116,8 +124,10 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l

     val partitioner = new HashPartitioner(numBlocks)

-    val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, Rating(u, p, r)) }
-    val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, Rating(p, u, r)) }
+    val ratingsByUserBlock = ratings.map{ rating => (rating.user % numBlocks, rating) }
+    val ratingsByProductBlock = ratings.map{ rating =>
+      (rating.product % numBlocks, Rating(rating.product, rating.user, rating.rating))
+    }

     val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
     val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)

@@ -356,14 +366,14 @@ object ALS {
   * @param blocks     level of parallelism to split computation into
   */
  def train(
-      ratings: RDD[(Int, Int, Double)],
+      ratings: RDD[Rating],
       rank: Int,
       iterations: Int,
       lambda: Double,
       blocks: Int)
     : MatrixFactorizationModel =
   {
-    new ALS(blocks, rank, iterations, lambda).train(ratings)
+    new ALS(blocks, rank, iterations, lambda).run(ratings)
   }

   /**

@@ -378,7 +388,7 @@ object ALS {
   * @param iterations number of iterations of ALS (recommended: 10-20)
   * @param lambda     regularization factor (recommended: 0.01)
   */
-  def train(ratings: RDD[(Int, Int, Double)], rank: Int, iterations: Int, lambda: Double)
+  def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double)
     : MatrixFactorizationModel =
   {
     train(ratings, rank, iterations, lambda, -1)

@@ -395,7 +405,7 @@ object ALS {
   * @param rank       number of features to use
   * @param iterations number of iterations of ALS (recommended: 10-20)
   */
-  def train(ratings: RDD[(Int, Int, Double)], rank: Int, iterations: Int)
+  def train(ratings: RDD[Rating], rank: Int, iterations: Int)
     : MatrixFactorizationModel =
   {
     train(ratings, rank, iterations, 0.01, -1)

@@ -423,7 +433,7 @@ object ALS {
     val sc = new SparkContext(master, "ALS")
     val ratings = sc.textFile(ratingsFile).map { line =>
       val fields = line.split(',')
-      (fields(0).toInt, fields(1).toInt, fields(2).toDouble)
+      Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
     }
     val model = ALS.train(ratings, rank, iters, 0.01, blocks)
     model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
@@ -0,0 +1,104 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package spark.mllib.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;

public class JavaKMeansSuite implements Serializable {
  private transient JavaSparkContext sc;

  @Before
  public void setUp() {
    sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
  }

  @After
  public void tearDown() {
    sc.stop();
    sc = null;
    System.clearProperty("spark.driver.port");
  }

  // Maximum absolute per-coordinate difference between two points (L-infinity distance)
  double distance1(double[] v1, double[] v2) {
    double distance = 0.0;
    for (int i = 0; i < v1.length; ++i) {
      distance = Math.max(distance, Math.abs(v1[i] - v2[i]));
    }
    return distance;
  }

  // Assert that two sets of points are equal, within EPSILON tolerance
  void assertSetsEqual(double[][] v1, double[][] v2) {
    double EPSILON = 1e-4;
    Assert.assertTrue(v1.length == v2.length);
    for (int i = 0; i < v1.length; ++i) {
      double minDistance = Double.MAX_VALUE;
      for (int j = 0; j < v2.length; ++j) {
        minDistance = Math.min(minDistance, distance1(v1[i], v2[j]));
      }
      Assert.assertTrue(minDistance <= EPSILON);
    }

    for (int i = 0; i < v2.length; ++i) {
      double minDistance = Double.MAX_VALUE;
      for (int j = 0; j < v1.length; ++j) {
        minDistance = Math.min(minDistance, distance1(v2[i], v1[j]));
      }
      Assert.assertTrue(minDistance <= EPSILON);
    }
  }

  @Test
  public void runKMeansUsingStaticMethods() {
    List<double[]> points = new ArrayList<double[]>();
    points.add(new double[]{1.0, 2.0, 6.0});
    points.add(new double[]{1.0, 3.0, 0.0});
    points.add(new double[]{1.0, 4.0, 6.0});

    double[][] expectedCenter = { {1.0, 3.0, 4.0} };

    JavaRDD<double[]> data = sc.parallelize(points, 2);
    KMeansModel model = KMeans.train(data.rdd(), 1, 1);
    assertSetsEqual(model.clusterCenters(), expectedCenter);
  }

  @Test
  public void runKMeansUsingConstructor() {
    List<double[]> points = new ArrayList<double[]>();
    points.add(new double[]{1.0, 2.0, 6.0});
    points.add(new double[]{1.0, 3.0, 0.0});
    points.add(new double[]{1.0, 4.0, 6.0});

    double[][] expectedCenter = { {1.0, 3.0, 4.0} };

    JavaRDD<double[]> data = sc.parallelize(points, 2);
    KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
    assertSetsEqual(model.clusterCenters(), expectedCenter);
  }
}
@@ -17,6 +17,7 @@

 package spark.mllib.recommendation

+import scala.collection.JavaConversions._
 import scala.util.Random

 import org.scalatest.BeforeAndAfterAll

@@ -27,6 +28,42 @@ import spark.SparkContext._

 import org.jblas._

+object ALSSuite {
+
+  def generateRatingsAsJavaList(
+      users: Int,
+      products: Int,
+      features: Int,
+      samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = {
+    val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate)
+    (seqAsJavaList(sampledRatings), trueRatings)
+  }
+
+  def generateRatings(
+      users: Int,
+      products: Int,
+      features: Int,
+      samplingRate: Double): (Seq[Rating], DoubleMatrix) = {
+    val rand = new Random(42)
+
+    // Create a random matrix with uniform values from -1 to 1
+    def randomMatrix(m: Int, n: Int) =
+      new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
+
+    val userMatrix = randomMatrix(users, features)
+    val productMatrix = randomMatrix(features, products)
+    val trueRatings = userMatrix.mmul(productMatrix)
+
+    val sampledRatings = {
+      for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
+        yield Rating(u, p, trueRatings.get(u, p))
+    }
+
+    (sampledRatings, trueRatings)
+  }
+
+}
+
 class ALSSuite extends FunSuite with BeforeAndAfterAll {
   val sc = new SparkContext("local", "test")

@@ -57,21 +94,8 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
   def testALS(users: Int, products: Int, features: Int, iterations: Int,
     samplingRate: Double, matchThreshold: Double)
   {
-    val rand = new Random(42)
-
-    // Create a random matrix with uniform values from -1 to 1
-    def randomMatrix(m: Int, n: Int) =
-      new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
-
-    val userMatrix = randomMatrix(users, features)
-    val productMatrix = randomMatrix(features, products)
-    val trueRatings = userMatrix.mmul(productMatrix)
-
-    val sampledRatings = {
-      for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
-        yield (u, p, trueRatings.get(u, p))
-    }
-
+    val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products,
+      features, samplingRate)
     val model = ALS.train(sc.parallelize(sampledRatings), features, iterations)

     val predictedU = new DoubleMatrix(users, features)
@@ -0,0 +1,110 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package spark.mllib.recommendation;

import java.io.Serializable;
import java.util.List;

import scala.Tuple2;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;

import org.jblas.DoubleMatrix;

public class JavaALSSuite implements Serializable {
  private transient JavaSparkContext sc;

  @Before
  public void setUp() {
    sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
  }

  @After
  public void tearDown() {
    sc.stop();
    sc = null;
    System.clearProperty("spark.driver.port");
  }

  void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
      DoubleMatrix trueRatings, double matchThreshold) {
    DoubleMatrix predictedU = new DoubleMatrix(users, features);
    List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
    for (int i = 0; i < features; ++i) {
      for (scala.Tuple2<Object, double[]> userFeature : userFeatures) {
        predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]);
      }
    }
    DoubleMatrix predictedP = new DoubleMatrix(products, features);

    List<scala.Tuple2<Object, double[]>> productFeatures =
        model.productFeatures().toJavaRDD().collect();
    for (int i = 0; i < features; ++i) {
      for (scala.Tuple2<Object, double[]> productFeature : productFeatures) {
        predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]);
      }
    }

    DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());

    for (int u = 0; u < users; ++u) {
      for (int p = 0; p < products; ++p) {
        double prediction = predictedRatings.get(u, p);
        double correct = trueRatings.get(u, p);
        Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold);
      }
    }
  }

  @Test
  public void runALSUsingStaticMethods() {
    int features = 1;
    int iterations = 15;
    int users = 10;
    int products = 10;
    scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
        users, products, features, 0.7);

    JavaRDD<Rating> data = sc.parallelize(testData._1());
    MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
    validatePrediction(model, users, products, features, testData._2(), 0.3);
  }

  @Test
  public void runALSUsingConstructor() {
    int features = 2;
    int iterations = 15;
    int users = 20;
    int products = 30;
    scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
        users, products, features, 0.7);

    JavaRDD<Rating> data = sc.parallelize(testData._1());

    MatrixFactorizationModel model = new ALS().setRank(features)
        .setIterations(iterations)
        .run(data.rdd());
    validatePrediction(model, users, products, features, testData._2(), 0.3);
  }
}
@@ -46,7 +46,7 @@ object SparkBuild extends Build {

   lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) dependsOn(bagel) dependsOn(mllib)

-  lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming)
+  lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming) dependsOn(mllib)

   lazy val tools = Project("tools", file("tools"), settings = examplesSettings) dependsOn (core) dependsOn (streaming)