Merge branch 'master' into streaming

Conflicts:
	core/src/main/scala/spark/rdd/CoGroupedRDD.scala
	core/src/main/scala/spark/rdd/FilteredRDD.scala
	docs/_layouts/global.html
	docs/index.md
	run
This commit is contained in:
Tathagata Das 2013-01-15 12:08:51 -08:00
Родитель 1638fcb0dc cb867e9ffb
Коммит cd1521cfdb
69 изменённых файлов: 3500 добавлений и 294 удалений

Просмотреть файл

@ -45,6 +45,11 @@
<profiles> <profiles>
<profile> <profile>
<id>hadoop1</id> <id>hadoop1</id>
<activation>
<property>
<name>!hadoopVersion</name>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.spark-project</groupId> <groupId>org.spark-project</groupId>
@ -72,6 +77,12 @@
</profile> </profile>
<profile> <profile>
<id>hadoop2</id> <id>hadoop2</id>
<activation>
<property>
<name>hadoopVersion</name>
<value>2</value>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.spark-project</groupId> <groupId>org.spark-project</groupId>

Просмотреть файл

@ -71,6 +71,10 @@
<groupId>cc.spray</groupId> <groupId>cc.spray</groupId>
<artifactId>spray-server</artifactId> <artifactId>spray-server</artifactId>
</dependency> </dependency>
<dependency>
<groupId>cc.spray</groupId>
<artifactId>spray-json_${scala.version}</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.tomdz.twirl</groupId> <groupId>org.tomdz.twirl</groupId>
<artifactId>twirl-api</artifactId> <artifactId>twirl-api</artifactId>
@ -159,6 +163,11 @@
<profiles> <profiles>
<profile> <profile>
<id>hadoop1</id> <id>hadoop1</id>
<activation>
<property>
<name>!hadoopVersion</name>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.apache.hadoop</groupId> <groupId>org.apache.hadoop</groupId>
@ -211,6 +220,12 @@
</profile> </profile>
<profile> <profile>
<id>hadoop2</id> <id>hadoop2</id>
<activation>
<property>
<name>hadoopVersion</name>
<value>2</value>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.apache.hadoop</groupId> <groupId>org.apache.hadoop</groupId>
@ -267,4 +282,4 @@
</build> </build>
</profile> </profile>
</profiles> </profiles>
</project> </project>

Просмотреть файл

@ -38,20 +38,37 @@ class Accumulable[R, T] (
*/ */
def += (term: T) { value_ = param.addAccumulator(value_, term) } def += (term: T) { value_ = param.addAccumulator(value_, term) }
/**
* Add more data to this accumulator / accumulable
* @param term the data to add
*/
def add(term: T) { value_ = param.addAccumulator(value_, term) }
/** /**
* Merge two accumulable objects together * Merge two accumulable objects together
* *
* Normally, a user will not want to use this version, but will instead call `+=`. * Normally, a user will not want to use this version, but will instead call `+=`.
* @param term the other Accumulable that will get merged with this * @param term the other `R` that will get merged with this
*/ */
def ++= (term: R) { value_ = param.addInPlace(value_, term)} def ++= (term: R) { value_ = param.addInPlace(value_, term)}
/**
* Merge two accumulable objects together
*
* Normally, a user will not want to use this version, but will instead call `add`.
* @param term the other `R` that will get merged with this
*/
def merge(term: R) { value_ = param.addInPlace(value_, term)}
/** /**
* Access the accumulator's current value; only allowed on master. * Access the accumulator's current value; only allowed on master.
*/ */
def value = { def value: R = {
if (!deserialized) value_ if (!deserialized) {
else throw new UnsupportedOperationException("Can't read accumulator value in task") value_
} else {
throw new UnsupportedOperationException("Can't read accumulator value in task")
}
} }
/** /**
@ -68,10 +85,17 @@ class Accumulable[R, T] (
/** /**
* Set the accumulator's value; only allowed on master. * Set the accumulator's value; only allowed on master.
*/ */
def value_= (r: R) { def value_= (newValue: R) {
if (!deserialized) value_ = r if (!deserialized) value_ = newValue
else throw new UnsupportedOperationException("Can't assign accumulator value in task") else throw new UnsupportedOperationException("Can't assign accumulator value in task")
} }
/**
* Set the accumulator's value; only allowed on master
*/
def setValue(newValue: R) {
this.value = newValue
}
// Called by Java when deserializing an object // Called by Java when deserializing an object
private def readObject(in: ObjectInputStream) { private def readObject(in: ObjectInputStream) {

Просмотреть файл

@ -1,118 +0,0 @@
package spark
import java.util.LinkedHashMap
/**
* An implementation of Cache that estimates the sizes of its entries and attempts to limit its
* total memory usage to a fraction of the JVM heap. Objects' sizes are estimated using
* SizeEstimator, which has limitations; most notably, we will overestimate total memory used if
* some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
* when most of the space is used by arrays of primitives or of simple classes.
*/
private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
def this() {
this(BoundedMemoryCache.getMaxBytes)
}
private var currentBytes = 0L
private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true)
override def get(datasetId: Any, partition: Int): Any = {
synchronized {
val entry = map.get((datasetId, partition))
if (entry != null) {
entry.value
} else {
null
}
}
}
override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
val key = (datasetId, partition)
logInfo("Asked to add key " + key)
val size = estimateValueSize(key, value)
synchronized {
if (size > getCapacity) {
return CachePutFailure()
} else if (ensureFreeSpace(datasetId, size)) {
logInfo("Adding key " + key)
map.put(key, new Entry(value, size))
currentBytes += size
logInfo("Number of entries is now " + map.size)
return CachePutSuccess(size)
} else {
logInfo("Didn't add key " + key + " because we would have evicted part of same dataset")
return CachePutFailure()
}
}
}
override def getCapacity: Long = maxBytes
/**
* Estimate sizeOf 'value'
*/
private def estimateValueSize(key: (Any, Int), value: Any) = {
val startTime = System.currentTimeMillis
val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
val timeTaken = System.currentTimeMillis - startTime
logInfo("Estimated size for key %s is %d".format(key, size))
logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
size
}
/**
* Remove least recently used entries from the map until at least space bytes are free, in order
* to make space for a partition from the given dataset ID. If this cannot be done without
* evicting other data from the same dataset, returns false; otherwise, returns true. Assumes
* that a lock is held on the BoundedMemoryCache.
*/
private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = {
logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format(
datasetId, space, currentBytes, maxBytes))
val iter = map.entrySet.iterator // Will give entries in LRU order
while (maxBytes - currentBytes < space && iter.hasNext) {
val mapEntry = iter.next()
val (entryDatasetId, entryPartition) = mapEntry.getKey
if (entryDatasetId == datasetId) {
// Cannot make space without removing part of the same dataset, or a more recently used one
return false
}
reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue)
currentBytes -= mapEntry.getValue.size
iter.remove()
}
return true
}
protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
// TODO: remove BoundedMemoryCache
val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)]
innerDatasetId match {
case rddId: Int =>
SparkEnv.get.cacheTracker.dropEntry(rddId, partition)
case broadcastUUID: java.util.UUID =>
// TODO: Maybe something should be done if the broadcasted variable falls out of cache
case _ =>
}
}
}
// An entry in our map; stores a cached object and its size in bytes
private[spark] case class Entry(value: Any, size: Long)
private[spark] object BoundedMemoryCache {
/**
* Get maximum cache capacity from system configuration
*/
def getMaxBytes: Long = {
val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong
}
}

Просмотреть файл

@ -615,6 +615,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.cleanup() writer.cleanup()
} }
/**
* Return an RDD with the keys of each tuple.
*/
def keys: RDD[K] = self.map(_._1)
/**
* Return an RDD with the values of each tuple.
*/
def values: RDD[V] = self.map(_._2)
private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure

Просмотреть файл

@ -348,6 +348,13 @@ abstract class RDD[T: ClassManifest](
*/ */
def toArray(): Array[T] = collect() def toArray(): Array[T] = collect()
/**
* Return an RDD that contains all matching values by applying `f`.
*/
def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
filter(f.isDefinedAt).map(f)
}
/** /**
* Reduces the elements of this RDD using the specified associative binary operator. * Reduces the elements of this RDD using the specified associative binary operator.
*/ */
@ -529,6 +536,13 @@ abstract class RDD[T: ClassManifest](
.saveAsSequenceFile(path) .saveAsSequenceFile(path)
} }
/**
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: T => K): RDD[(K, T)] = {
map(x => (f(x), x))
}
/** A private method for tests, to look at the contents of each partition */ /** A private method for tests, to look at the contents of each partition */
private[spark] def collectPartitions(): Array[Array[T]] = { private[spark] def collectPartitions(): Array[Array[T]] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray) sc.runJob(this, (iter: Iterator[T]) => iter.toArray)

Просмотреть файл

@ -42,7 +42,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
classManifest[T].erasure classManifest[T].erasure
} else { } else {
implicitly[T => Writable].getClass.getMethods()(0).getReturnType // We get the type of the Writable class by looking at the apply method which converts
// from T to Writable. Since we have two apply methods we filter out the one which
// is of the form "java.lang.Object apply(java.lang.Object)"
implicitly[T => Writable].getClass.getDeclaredMethods().filter(
m => m.getReturnType().toString != "java.lang.Object" &&
m.getName() == "apply")(0).getReturnType
} }
// TODO: use something like WritableConverter to avoid reflection // TODO: use something like WritableConverter to avoid reflection
} }

Просмотреть файл

@ -9,7 +9,6 @@ import java.util.Random
import javax.management.MBeanServer import javax.management.MBeanServer
import java.lang.management.ManagementFactory import java.lang.management.ManagementFactory
import com.sun.management.HotSpotDiagnosticMXBean
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
@ -76,12 +75,20 @@ private[spark] object SizeEstimator extends Logging {
if (System.getProperty("spark.test.useCompressedOops") != null) { if (System.getProperty("spark.test.useCompressedOops") != null) {
return System.getProperty("spark.test.useCompressedOops").toBoolean return System.getProperty("spark.test.useCompressedOops").toBoolean
} }
try { try {
val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
val server = ManagementFactory.getPlatformMBeanServer() val server = ManagementFactory.getPlatformMBeanServer()
// NOTE: This should throw an exception in non-Sun JVMs
val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean")
val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption",
Class.forName("java.lang.String"))
val bean = ManagementFactory.newPlatformMXBeanProxy(server, val bean = ManagementFactory.newPlatformMXBeanProxy(server,
hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]) hotSpotMBeanName, hotSpotMBeanClass)
return bean.getVMOption("UseCompressedOops").getValue.toBoolean // TODO: We could use reflection on the VMOption returned ?
return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
} catch { } catch {
case e: Exception => { case e: Exception => {
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB

Просмотреть файл

@ -396,11 +396,12 @@ class SparkContext(
new Accumulator(initialValue, param) new Accumulator(initialValue, param)
/** /**
* Create an [[spark.Accumulable]] shared variable, with a `+=` method * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
* Only the master can access the accumuable's `value`.
* @tparam T accumulator type * @tparam T accumulator type
* @tparam R type that can be added to the accumulator * @tparam R type that can be added to the accumulator
*/ */
def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
new Accumulable(initialValue, param) new Accumulable(initialValue, param)
/** /**
@ -418,7 +419,7 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once. * reading it in distributed functions. The variable will be sent to each cluster only once.
*/ */
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal) def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
/** /**
* Add a file to be downloaded into the working directory of this Spark job on every node. * Add a file to be downloaded into the working directory of this Spark job on every node.

Просмотреть файл

@ -471,6 +471,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x) implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending)) fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending))
} }
/**
* Return an RDD with the keys of each tuple.
*/
def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
/**
* Return an RDD with the values of each tuple.
*/
def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2))
} }
object JavaPairRDD { object JavaPairRDD {

Просмотреть файл

@ -298,4 +298,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Save this RDD as a SequenceFile of serialized objects. * Save this RDD as a SequenceFile of serialized objects.
*/ */
def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
/**
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
} }

Просмотреть файл

@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext}
import spark.SparkContext.IntAccumulatorParam import spark.SparkContext.IntAccumulatorParam
import spark.SparkContext.DoubleAccumulatorParam import spark.SparkContext.DoubleAccumulatorParam
import spark.broadcast.Broadcast import spark.broadcast.Broadcast
@ -265,25 +265,45 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/** /**
* Create an [[spark.Accumulator]] integer variable, which tasks can "add" values * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
* to using the `+=` method. Only the master can access the accumulator's `value`. * to using the `add` method. Only the master can access the accumulator's `value`.
*/ */
def intAccumulator(initialValue: Int): Accumulator[Int] = def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
sc.accumulator(initialValue)(IntAccumulatorParam) sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
/** /**
* Create an [[spark.Accumulator]] double variable, which tasks can "add" values * Create an [[spark.Accumulator]] double variable, which tasks can "add" values
* to using the `+=` method. Only the master can access the accumulator's `value`. * to using the `add` method. Only the master can access the accumulator's `value`.
*/ */
def doubleAccumulator(initialValue: Double): Accumulator[Double] = def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
sc.accumulator(initialValue)(DoubleAccumulatorParam) sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
/**
* Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue)
/**
* Create an [[spark.Accumulator]] double variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator(initialValue: Double): Accumulator[java.lang.Double] =
doubleAccumulator(initialValue)
/** /**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
* to using the `+=` method. Only the master can access the accumulator's `value`. * to using the `add` method. Only the master can access the accumulator's `value`.
*/ */
def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
sc.accumulator(initialValue)(accumulatorParam) sc.accumulator(initialValue)(accumulatorParam)
/**
* Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can
* "add" values with `add`. Only the master can access the accumuable's `value`.
*/
def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
sc.accumulable(initialValue)(param)
/** /**
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once. * reading it in distributed functions. The variable will be sent to each cluster only once.

Просмотреть файл

@ -0,0 +1,39 @@
package spark.api.python
import spark.Partitioner
import java.util.Arrays
/**
* A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
*/
private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = {
if (key == null) {
return 0
}
else {
val hashCode = {
if (key.isInstanceOf[Array[Byte]]) {
Arrays.hashCode(key.asInstanceOf[Array[Byte]])
} else {
key.hashCode()
}
}
val mod = hashCode % numPartitions
if (mod < 0) {
mod + numPartitions
} else {
mod // Guard against negative hash codes
}
}
}
override def equals(other: Any): Boolean = other match {
case h: PythonPartitioner =>
h.numPartitions == numPartitions
case _ =>
false
}
}

Просмотреть файл

@ -0,0 +1,247 @@
package spark.api.python
import java.io._
import java.util.{List => JList}
import scala.collection.JavaConversions._
import scala.io.Source
import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
import spark._
import spark.rdd.PipedRDD
import java.util
private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
command: Seq[String],
envVars: java.util.Map[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: java.util.List[Broadcast[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
broadcastVars)
override def getSplits = parent.splits
override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = {
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
for ((variable, value) <- envVars) {
currentEnvVars.put(variable, value)
}
val proc = pb.start()
val env = SparkEnv.get
// Start a thread to print the process's stderr to ours
new Thread("stderr reader for " + command) {
override def run() {
for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
System.err.println(line)
}
}
}.start()
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + command) {
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
val dOut = new DataOutputStream(proc.getOutputStream)
// Split index
dOut.writeInt(split.index)
// Broadcast variables
dOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
dOut.writeLong(broadcast.id)
dOut.writeInt(broadcast.value.length)
dOut.write(broadcast.value)
dOut.flush()
}
// Serialized user code
for (elem <- command) {
out.println(elem)
}
out.flush()
// Data values
for (elem <- parent.iterator(split, context)) {
PythonRDD.writeAsPickle(elem, dOut)
}
dOut.flush()
out.flush()
proc.getOutputStream.close()
}
}.start()
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(proc.getInputStream)
return new Iterator[Array[Byte]] {
def next() = {
val obj = _nextObj
_nextObj = read()
obj
}
private def read() = {
try {
val length = stream.readInt()
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
} catch {
case eof: EOFException => {
val exitStatus = proc.waitFor()
if (exitStatus != 0) {
throw new Exception("Subprocess exited with status " + exitStatus)
}
new Array[Byte](0)
}
case e => throw e
}
}
var _nextObj = read()
def hasNext = _nextObj.length != 0
}
}
override def checkpoint() { }
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
* This is used by PySpark's shuffle operations.
*/
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Array[Byte], Array[Byte])](prev) {
override def getSplits = prev.splits
override def compute(split: Split, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
private[spark] object PythonRDD {
/** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
def stripPickle(arr: Array[Byte]) : Array[Byte] = {
arr.slice(2, arr.length - 1)
}
/**
* Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
* The data format is a 32-bit integer representing the pickled object's length (in bytes),
* followed by the pickled data.
*
* Pickle module:
*
* http://docs.python.org/2/library/pickle.html
*
* The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
*
* http://hg.python.org/cpython/file/2.6/Lib/pickle.py
* http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
*
* @param elem the object to write
* @param dOut a data output stream
*/
def writeAsPickle(elem: Any, dOut: DataOutputStream) {
if (elem.isInstanceOf[Array[Byte]]) {
val arr = elem.asInstanceOf[Array[Byte]]
dOut.writeInt(arr.length)
dOut.write(arr)
} else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(PythonRDD.stripPickle(t._1))
dOut.write(PythonRDD.stripPickle(t._2))
dOut.writeByte(Pickle.TUPLE2)
dOut.writeByte(Pickle.STOP)
} else if (elem.isInstanceOf[String]) {
// For uniformity, strings are wrapped into Pickles.
val s = elem.asInstanceOf[String].getBytes("UTF-8")
val length = 2 + 1 + 4 + s.length + 1
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(Pickle.BINUNICODE)
dOut.writeInt(Integer.reverseBytes(s.length))
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
throw new Exception("Unexpected RDD type")
}
}
def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
try {
while (true) {
val length = file.readInt()
val obj = new Array[Byte](length)
file.readFully(obj)
objs.append(obj)
}
} catch {
case eof: EOFException => {}
case e => throw e
}
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
writeAsPickle(item, file)
}
file.close()
}
def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
}
private object Pickle {
val PROTO: Byte = 0x80.toByte
val TWO: Byte = 0x02.toByte
val BINUNICODE: Byte = 'X'
val STOP: Byte = '.'
val TUPLE2: Byte = 0x86.toByte
val EMPTY_LIST: Byte = ']'
val MARK: Byte = '('
val APPENDS: Byte = 'e'
}
private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
Array[Byte]), Array[Byte]] {
override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
}
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}

Просмотреть файл

@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong
import spark._ import spark._
abstract class Broadcast[T](id: Long) extends Serializable { abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
def value: T def value: T
// We cannot have an abstract readObject here due to some weird issues with // We cannot have an abstract readObject here due to some weird issues with

Просмотреть файл

@ -0,0 +1,78 @@
package spark.deploy
import master.{JobInfo, WorkerInfo}
import worker.ExecutorRunner
import cc.spray.json._
/**
* spray-json helper class containing implicit conversion to json for marshalling responses
*/
private[spark] object JsonProtocol extends DefaultJsonProtocol {
implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] {
def write(obj: WorkerInfo) = JsObject(
"id" -> JsString(obj.id),
"host" -> JsString(obj.host),
"webuiaddress" -> JsString(obj.webUiAddress),
"cores" -> JsNumber(obj.cores),
"coresused" -> JsNumber(obj.coresUsed),
"memory" -> JsNumber(obj.memory),
"memoryused" -> JsNumber(obj.memoryUsed)
)
}
implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] {
def write(obj: JobInfo) = JsObject(
"starttime" -> JsNumber(obj.startTime),
"id" -> JsString(obj.id),
"name" -> JsString(obj.desc.name),
"cores" -> JsNumber(obj.desc.cores),
"user" -> JsString(obj.desc.user),
"memoryperslave" -> JsNumber(obj.desc.memoryPerSlave),
"submitdate" -> JsString(obj.submitDate.toString))
}
implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] {
def write(obj: JobDescription) = JsObject(
"name" -> JsString(obj.name),
"cores" -> JsNumber(obj.cores),
"memoryperslave" -> JsNumber(obj.memoryPerSlave),
"user" -> JsString(obj.user)
)
}
implicit object ExecutorRunnerJsonFormat extends RootJsonWriter[ExecutorRunner] {
def write(obj: ExecutorRunner) = JsObject(
"id" -> JsNumber(obj.execId),
"memory" -> JsNumber(obj.memory),
"jobid" -> JsString(obj.jobId),
"jobdesc" -> obj.jobDesc.toJson.asJsObject
)
}
implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] {
def write(obj: MasterState) = JsObject(
"url" -> JsString("spark://" + obj.uri),
"workers" -> JsArray(obj.workers.toList.map(_.toJson)),
"cores" -> JsNumber(obj.workers.map(_.cores).sum),
"coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum),
"memory" -> JsNumber(obj.workers.map(_.memory).sum),
"memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum),
"activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)),
"completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson))
)
}
implicit object WorkerStateJsonFormat extends RootJsonWriter[WorkerState] {
def write(obj: WorkerState) = JsObject(
"id" -> JsString(obj.workerId),
"masterurl" -> JsString(obj.masterUrl),
"masterwebuiurl" -> JsString(obj.masterWebUiUrl),
"cores" -> JsNumber(obj.cores),
"coresused" -> JsNumber(obj.coresUsed),
"memory" -> JsNumber(obj.memory),
"memoryused" -> JsNumber(obj.memoryUsed),
"executors" -> JsArray(obj.executors.toList.map(_.toJson)),
"finishedexecutors" -> JsArray(obj.finishedExecutors.toList.map(_.toJson))
)
}
}

Просмотреть файл

@ -8,7 +8,11 @@ import akka.util.duration._
import cc.spray.Directives import cc.spray.Directives
import cc.spray.directives._ import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._ import cc.spray.typeconversion.TwirlSupport._
import cc.spray.http.MediaTypes
import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy._ import spark.deploy._
import spark.deploy.JsonProtocol._
private[spark] private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
@ -19,29 +23,51 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val handler = { val handler = {
get { get {
path("") { (path("") & parameters('format ?)) {
completeWith { case Some(js) if js.equalsIgnoreCase("json") =>
val future = master ? RequestMasterState val future = master ? RequestMasterState
future.map { respondWithMediaType(MediaTypes.`application/json`) { ctx =>
masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) ctx.complete(future.mapTo[MasterState])
} }
} case _ =>
} ~
path("job") {
parameter("jobId") { jobId =>
completeWith { completeWith {
val future = master ? RequestMasterState val future = master ? RequestMasterState
future.map { state => future.map {
val masterState = state.asInstanceOf[MasterState] masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState])
// A bit ugly an inefficient, but we won't have a number of jobs
// so large that it will make a significant difference.
(masterState.activeJobs ++ masterState.completedJobs).find(_.id == jobId) match {
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null
}
} }
} }
} ~
path("job") {
parameters("jobId", 'format ?) {
case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
masterState.activeJobs.find(_.id == jobId) match {
case Some(job) => job
case _ => masterState.completedJobs.find(_.id == jobId) match {
case Some(job) => job
case _ => null
}
}
}
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
ctx.complete(jobInfo.mapTo[JobInfo])
}
case (jobId, _) =>
completeWith {
val future = master ? RequestMasterState
future.map { state =>
val masterState = state.asInstanceOf[MasterState]
masterState.activeJobs.find(_.id == jobId) match {
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => masterState.completedJobs.find(_.id == jobId) match {
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null
}
}
}
}
} }
} ~ } ~
pathPrefix("static") { pathPrefix("static") {

Просмотреть файл

@ -104,9 +104,25 @@ private[spark] class WorkerArguments(args: Array[String]) {
} }
def inferDefaultMemory(): Int = { def inferDefaultMemory(): Int = {
val bean = ManagementFactory.getOperatingSystemMXBean val ibmVendor = System.getProperty("java.vendor").contains("IBM")
.asInstanceOf[com.sun.management.OperatingSystemMXBean] var totalMb = 0
val totalMb = (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt try {
val bean = ManagementFactory.getOperatingSystemMXBean()
if (ibmVendor) {
val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean")
val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory")
totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
} else {
val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean")
val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize")
totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
}
} catch {
case e: Exception => {
totalMb = 2*1024
System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
}
}
// Leave out 1 GB for the operating system, but don't return a negative memory size // Leave out 1 GB for the operating system, but don't return a negative memory size
math.max(totalMb - 1024, 512) math.max(totalMb - 1024, 512)
} }

Просмотреть файл

@ -7,7 +7,11 @@ import akka.util.Timeout
import akka.util.duration._ import akka.util.duration._
import cc.spray.Directives import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._ import cc.spray.typeconversion.TwirlSupport._
import cc.spray.http.MediaTypes
import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy.{WorkerState, RequestWorkerState} import spark.deploy.{WorkerState, RequestWorkerState}
import spark.deploy.JsonProtocol._
private[spark] private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
@ -18,13 +22,20 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
val handler = { val handler = {
get { get {
path("") { (path("") & parameters('format ?)) {
completeWith{ case Some(js) if js.equalsIgnoreCase("json") => {
val future = worker ? RequestWorkerState val future = worker ? RequestWorkerState
future.map { workerState => respondWithMediaType(MediaTypes.`application/json`) { ctx =>
spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) ctx.complete(future.mapTo[WorkerState])
} }
} }
case _ =>
completeWith{
val future = worker ? RequestWorkerState
future.map { workerState =>
spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState])
}
}
} ~ } ~
path("log") { path("log") {
parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) => parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) =>

Просмотреть файл

@ -135,8 +135,11 @@ extends Connection(SocketChannel.open, selector_) {
val chunk = message.getChunkForSending(defaultChunkSize) val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) { if (chunk.isDefined) {
messages += message // this is probably incorrect, it wont work as fifo messages += message // this is probably incorrect, it wont work as fifo
if (!message.started) logDebug("Starting to send [" + message + "]") if (!message.started) {
message.started = true logDebug("Starting to send [" + message + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
return chunk return chunk
} else { } else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/

Просмотреть файл

@ -43,12 +43,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
} }
val selector = SelectorProvider.provider.openSelector() val selector = SelectorProvider.provider.openSelector()
val handleMessageExecutor = Executors.newFixedThreadPool(4) val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
val serverChannel = ServerSocketChannel.open() val serverChannel = ServerSocketChannel.open()
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val messageStatuses = new HashMap[Int, MessageStatus] val messageStatuses = new HashMap[Int, MessageStatus]
val connectionRequests = new SynchronizedQueue[SendingConnection] val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)] val sendMessageRequests = new Queue[(Message, SendingConnection)]
@ -79,10 +79,10 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def run() { def run() {
try { try {
while(!selectorThread.isInterrupted) { while(!selectorThread.isInterrupted) {
while(!connectionRequests.isEmpty) { for( (connectionManagerId, sendingConnection) <- connectionRequests) {
val sendingConnection = connectionRequests.dequeue
sendingConnection.connect() sendingConnection.connect()
addConnection(sendingConnection) addConnection(sendingConnection)
connectionRequests -= connectionManagerId
} }
sendMessageRequests.synchronized { sendMessageRequests.synchronized {
while(!sendMessageRequests.isEmpty) { while(!sendMessageRequests.isEmpty) {
@ -300,8 +300,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = { def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
val newConnection = new SendingConnection(inetSocketAddress, selector) val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector))
connectionRequests += newConnection
newConnection newConnection
} }
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
@ -473,6 +472,7 @@ private[spark] object ConnectionManager {
val mb = size * count / 1024.0 / 1024.0 val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime val ms = finishTime - startTime
val tput = mb * 1000.0 / ms val tput = mb * 1000.0 / ms
println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------") println("--------------------------")
println() println()
} }

Просмотреть файл

@ -13,8 +13,14 @@ import akka.util.duration._
private[spark] object ConnectionManagerTest extends Logging{ private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) { def main(args: Array[String]) {
//<mesos cluster> - the master URL
//<slaves file> - a list slaves to run connectionTest on
//[num of tasks] - the number of parallel tasks to be initiated default is number of slave hosts
//[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10
//[count] - how many times to run, default is 3
//[await time in seconds] : await time (in seconds), default is 600
if (args.length < 2) { if (args.length < 2) {
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>") println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ")
System.exit(1) System.exit(1)
} }
@ -29,16 +35,19 @@ private[spark] object ConnectionManagerTest extends Logging{
/*println("Slaves")*/ /*println("Slaves")*/
/*slaves.foreach(println)*/ /*slaves.foreach(println)*/
val tasknum = if (args.length > 2) args(2).toInt else slaves.length
val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map( val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
val count = if (args.length > 4) args(4).toInt else 3
val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second
println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime)
val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
i => SparkEnv.get.connectionManager.id).collect() i => SparkEnv.get.connectionManager.id).collect()
println("\nSlave ConnectionManagerIds") println("\nSlave ConnectionManagerIds")
slaveConnManagerIds.foreach(println) slaveConnManagerIds.foreach(println)
println println
val count = 10
(0 until count).foreach(i => { (0 until count).foreach(i => {
val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => { val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
val connManager = SparkEnv.get.connectionManager val connManager = SparkEnv.get.connectionManager
val thisConnManagerId = connManager.id val thisConnManagerId = connManager.id
connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
@ -46,7 +55,6 @@ private[spark] object ConnectionManagerTest extends Logging{
None None
}) })
val size = 100 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip buffer.flip
@ -56,13 +64,13 @@ private[spark] object ConnectionManagerTest extends Logging{
logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
}) })
val results = futures.map(f => Await.result(f, 1.second)) val results = futures.map(f => Await.result(f, awaitTime))
val finishTime = System.currentTimeMillis val finishTime = System.currentTimeMillis
Thread.sleep(5000) Thread.sleep(5000)
val mb = size * results.size / 1024.0 / 1024.0 val mb = size * results.size / 1024.0 / 1024.0
val ms = finishTime - startTime val ms = finishTime - startTime
val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
logInfo(resultStr) logInfo(resultStr)
resultStr resultStr
}).collect() }).collect()

Просмотреть файл

@ -1,9 +1,9 @@
package spark.rdd package spark.rdd
import java.io.{ObjectOutputStream, IOException} import java.io.{ObjectOutputStream, IOException}
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency} import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@ -86,9 +86,16 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit] val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]] val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = { def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) val seq = map.get(k)
if (seq != null) {
seq
} else {
val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
map.put(k, seq)
seq
}
} }
for ((dep, depNum) <- split.deps.zipWithIndex) dep match { for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => { case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => {
@ -108,7 +115,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
} }
} }
map.iterator JavaConversions.mapAsScalaMap(map).iterator
} }
override def clearDependencies() { override def clearDependencies() {

Просмотреть файл

@ -9,6 +9,8 @@ private[spark] class FilteredRDD[T: ClassManifest](
override def getSplits = firstParent[T].splits override def getSplits = firstParent[T].splits
override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
override def compute(split: Split, context: TaskContext) = override def compute(split: Split, context: TaskContext) =
firstParent[T].iterator(split, context).filter(f) firstParent[T].iterator(split, context).filter(f)
} }

Просмотреть файл

@ -201,7 +201,11 @@ private[spark] class TaskSetManager(
val taskId = sched.newTaskId() val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch // Figure out whether this should count as a preferred launch
val preferred = isPreferredLocation(task, host) val preferred = isPreferredLocation(task, host)
val prefStr = if (preferred) "preferred" else "non-preferred" val prefStr = if (preferred) {
"preferred"
} else {
"non-preferred, not one of " + task.preferredLocations.mkString(", ")
}
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId, slaveId, host, prefStr)) taskSet.id, index, taskId, slaveId, host, prefStr))
// Do various bookkeeping // Do various bookkeeping

Просмотреть файл

@ -1,58 +0,0 @@
package spark
import org.scalatest.FunSuite
import org.scalatest.PrivateMethodTester
import org.scalatest.matchers.ShouldMatchers
// TODO: Replace this with a test of MemoryStore
class BoundedMemoryCacheSuite extends FunSuite with PrivateMethodTester with ShouldMatchers {
test("constructor test") {
val cache = new BoundedMemoryCache(60)
expect(60)(cache.getCapacity)
}
test("caching") {
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
val oldArch = System.setProperty("os.arch", "amd64")
val oldOops = System.setProperty("spark.test.useCompressedOops", "true")
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
val cache = new BoundedMemoryCache(60) {
//TODO sorry about this, but there is not better way how to skip 'cacheTracker.dropEntry'
override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
}
}
// NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length
// This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
// http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
// Work around to check for either.
//should be OK
cache.put("1", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48)))
//we cannot add this to cache (there is not enough space in cache) & we cannot evict the only value from
//cache because it's from the same dataset
expect(CachePutFailure())(cache.put("1", 1, "Meh"))
//should be OK, dataset '1' can be evicted from cache
cache.put("2", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48)))
//should fail, cache should obey it's capacity
expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string"))
if (oldArch != null) {
System.setProperty("os.arch", oldArch)
} else {
System.clearProperty("os.arch")
}
if (oldOops != null) {
System.setProperty("spark.test.useCompressedOops", oldOops)
} else {
System.clearProperty("spark.test.useCompressedOops")
}
}
}

Просмотреть файл

@ -581,4 +581,64 @@ public class JavaAPISuite implements Serializable {
JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles); JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
zipped.count(); zipped.count();
} }
@Test
public void accumulators() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
final Accumulator<Integer> intAccum = sc.accumulator(10);
rdd.foreach(new VoidFunction<Integer>() {
public void call(Integer x) {
intAccum.add(x);
}
});
Assert.assertEquals((Integer) 25, intAccum.value());
final Accumulator<Double> doubleAccum = sc.accumulator(10.0);
rdd.foreach(new VoidFunction<Integer>() {
public void call(Integer x) {
doubleAccum.add((double) x);
}
});
Assert.assertEquals((Double) 25.0, doubleAccum.value());
// Try a custom accumulator type
AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() {
public Float addInPlace(Float r, Float t) {
return r + t;
}
public Float addAccumulator(Float r, Float t) {
return r + t;
}
public Float zero(Float initialValue) {
return 0.0f;
}
};
final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam);
rdd.foreach(new VoidFunction<Integer>() {
public void call(Integer x) {
floatAccum.add((float) x);
}
});
Assert.assertEquals((Float) 25.0f, floatAccum.value());
// Test the setValue method
floatAccum.setValue(5.0f);
Assert.assertEquals((Float) 5.0f, floatAccum.value());
}
@Test
public void keyBy() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2));
List<Tuple2<String, Integer>> s = rdd.keyBy(new Function<Integer, String>() {
public String call(Integer t) throws Exception {
return t.toString();
}
}).collect();
Assert.assertEquals(new Tuple2<String, Integer>("1", 1), s.get(0));
Assert.assertEquals(new Tuple2<String, Integer>("2", 2), s.get(1));
}
} }

Просмотреть файл

@ -106,6 +106,11 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter {
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.map(_ => 1).partitioner === None)
assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner)
assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner)
assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner)
} }
test("partitioning Java arrays should fail") { test("partitioning Java arrays should fail") {

Просмотреть файл

@ -35,6 +35,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7)) assert(partitionSums.collect().toList === List(3, 7))

Просмотреть файл

@ -216,6 +216,13 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
// Test that a shuffle on the file works, because this used to be a bug // Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
} }
test("keys and values") {
sc = new SparkContext("local", "test")
val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
assert(rdd.keys.collect().toList === List(1, 2))
assert(rdd.values.collect().toList === List("a", "b"))
}
} }
object ShuffleSuite { object ShuffleSuite {

Просмотреть файл

@ -3,7 +3,6 @@ package spark
import org.scalatest.FunSuite import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterAll
import org.scalatest.PrivateMethodTester import org.scalatest.PrivateMethodTester
import org.scalatest.matchers.ShouldMatchers
class DummyClass1 {} class DummyClass1 {}
@ -20,8 +19,17 @@ class DummyClass4(val d: DummyClass3) {
val x: Int = 0 val x: Int = 0
} }
object DummyString {
def apply(str: String) : DummyString = new DummyString(str.toArray)
}
class DummyString(val arr: Array[Char]) {
override val hashCode: Int = 0
// JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f
@transient val hash32: Int = 0
}
class SizeEstimatorSuite class SizeEstimatorSuite
extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers { extends FunSuite with BeforeAndAfterAll with PrivateMethodTester {
var oldArch: String = _ var oldArch: String = _
var oldOops: String = _ var oldOops: String = _
@ -45,15 +53,13 @@ class SizeEstimatorSuite
expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3))) expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
} }
// NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length. // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
// This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. // (Sun vs IBM). Use a DummyString class to make tests deterministic.
// http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
// Work around to check for either.
test("strings") { test("strings") {
SizeEstimator.estimate("") should (equal (48) or equal (40)) expect(40)(SizeEstimator.estimate(DummyString("")))
SizeEstimator.estimate("a") should (equal (56) or equal (48)) expect(48)(SizeEstimator.estimate(DummyString("a")))
SizeEstimator.estimate("ab") should (equal (56) or equal (48)) expect(48)(SizeEstimator.estimate(DummyString("ab")))
SizeEstimator.estimate("abcdefgh") should (equal(64) or equal(56)) expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
} }
test("primitive arrays") { test("primitive arrays") {
@ -105,18 +111,16 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize) val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize() SizeEstimator invokePrivate initialize()
expect(40)(SizeEstimator.estimate("")) expect(40)(SizeEstimator.estimate(DummyString("")))
expect(48)(SizeEstimator.estimate("a")) expect(48)(SizeEstimator.estimate(DummyString("a")))
expect(48)(SizeEstimator.estimate("ab")) expect(48)(SizeEstimator.estimate(DummyString("ab")))
expect(56)(SizeEstimator.estimate("abcdefgh")) expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch) resetOrClear("os.arch", arch)
} }
// NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length. // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
// This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. // (Sun vs IBM). Use a DummyString class to make tests deterministic.
// http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
// Work around to check for either.
test("64-bit arch with no compressed oops") { test("64-bit arch with no compressed oops") {
val arch = System.setProperty("os.arch", "amd64") val arch = System.setProperty("os.arch", "amd64")
val oops = System.setProperty("spark.test.useCompressedOops", "false") val oops = System.setProperty("spark.test.useCompressedOops", "false")
@ -124,10 +128,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize) val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize() SizeEstimator invokePrivate initialize()
SizeEstimator.estimate("") should (equal (64) or equal (56)) expect(56)(SizeEstimator.estimate(DummyString("")))
SizeEstimator.estimate("a") should (equal (72) or equal (64)) expect(64)(SizeEstimator.estimate(DummyString("a")))
SizeEstimator.estimate("ab") should (equal (72) or equal (64)) expect(64)(SizeEstimator.estimate(DummyString("ab")))
SizeEstimator.estimate("abcdefgh") should (equal (80) or equal (72)) expect(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch) resetOrClear("os.arch", arch)
resetOrClear("spark.test.useCompressedOops", oops) resetOrClear("spark.test.useCompressedOops", oops)

Просмотреть файл

@ -25,10 +25,12 @@ To mark a block of code in your markdown to be syntax highlighted by jekyll duri
// supported languages too. // supported languages too.
{% endhighlight %} {% endhighlight %}
## Scaladoc ## API Docs (Scaladoc and Epydoc)
You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory.
When you run `jekyll` in the docs directory, it will also copy over the scala doc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the SPARK_PROJECT_ROOT/pyspark directory.
NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`. When you run `jekyll` in the docs directory, it will also copy over the scaladoc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs using [epydoc](http://epydoc.sourceforge.net/).
NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`. Similarly, `SKIP_EPYDOC=1 jekyll` will skip PySpark API doc generation.

Просмотреть файл

@ -47,6 +47,7 @@
<li><a href="quick-start.html">Quick Start</a></li> <li><a href="quick-start.html">Quick Start</a></li>
<li><a href="scala-programming-guide.html">Scala</a></li> <li><a href="scala-programming-guide.html">Scala</a></li>
<li><a href="java-programming-guide.html">Java</a></li> <li><a href="java-programming-guide.html">Java</a></li>
<li><a href="python-programming-guide.html">Python</a></li>
<li><a href="streaming-programming-guide.html">Spark Streaming</a></li> <li><a href="streaming-programming-guide.html">Spark Streaming</a></li>
</ul> </ul>
</li> </li>
@ -54,10 +55,9 @@
<li class="dropdown"> <li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">API (Scaladoc)<b class="caret"></b></a> <a href="#" class="dropdown-toggle" data-toggle="dropdown">API (Scaladoc)<b class="caret"></b></a>
<ul class="dropdown-menu"> <ul class="dropdown-menu">
<li><a href="api/core/index.html">Spark</a></li> <li><a href="api/core/index.html">Spark Scala/Java (Scaladoc)</a></li>
<li><a href="api/examples/index.html">Spark Examples</a></li> <li><a href="api/pyspark/index.html">Spark Python (Epydoc)</a></li>
<li><a href="api/streaming/index.html">Spark Streaming</a></li> <li><a href="api/streaming/index.html">Spark Streaming Scala/Java (Scaladoc) </a></li>
<li><a href="api/bagel/index.html">Bagel</a></li>
</ul> </ul>
</li> </li>

Просмотреть файл

@ -28,3 +28,20 @@ if ENV['SKIP_SCALADOC'] != '1'
cp_r(source + "/.", dest) cp_r(source + "/.", dest)
end end
end end
if ENV['SKIP_EPYDOC'] != '1'
puts "Moving to python directory and building epydoc."
cd("../python")
puts `epydoc --config epydoc.conf`
puts "Moving back into docs dir."
cd("../docs")
puts "echo making directory pyspark"
mkdir_p "pyspark"
puts "cp -r ../python/docs/. api/pyspark"
cp_r("../python/docs/.", "api/pyspark")
cd("..")
end

Просмотреть файл

@ -9,3 +9,4 @@ Here you can find links to the Scaladoc generated for the Spark sbt subprojects.
- [Spark Examples](api/examples/index.html) - [Spark Examples](api/examples/index.html)
- [Spark Streaming](api/streaming/index.html) - [Spark Streaming](api/streaming/index.html)
- [Bagel](api/bagel/index.html) - [Bagel](api/bagel/index.html)
- [PySpark](api/pyspark/index.html)

Просмотреть файл

@ -7,11 +7,11 @@ title: Spark Overview
TODO(andyk): Rewrite to make the Java API a first class part of the story. TODO(andyk): Rewrite to make the Java API a first class part of the story.
{% endcomment %} {% endcomment %}
Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an interpreter.
interpreter. It provides clean, language-integrated APIs in Scala and Java, with a rich array of parallel operators. Spark can It provides clean, language-integrated APIs in [Scala](scala-programming-guide.html), [Java](java-programming-guide.html), and [Python](python-programming-guide.html), with a rich array of parallel operators.
run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager, Spark can run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager,
[Hadoop YARN](http://hadoop.apache.org/docs/r2.0.1-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html), [Hadoop YARN](http://hadoop.apache.org/docs/r2.0.1-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html),
Amazon EC2, or without an independent resource manager ("standalone mode"). Amazon EC2, or without an independent resource manager ("standalone mode").
# Downloading # Downloading
@ -58,8 +58,15 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`).
* [Quick Start](quick-start.html): a quick introduction to the Spark API; start here! * [Quick Start](quick-start.html): a quick introduction to the Spark API; start here!
* [Spark Programming Guide](scala-programming-guide.html): an overview of Spark concepts, and details on the Scala API * [Spark Programming Guide](scala-programming-guide.html): an overview of Spark concepts, and details on the Scala API
* [Streaming Programming Guide](streaming-programming-guide.html): an API preview of Spark Streaming
* [Java Programming Guide](java-programming-guide.html): using Spark from Java * [Java Programming Guide](java-programming-guide.html): using Spark from Java
* [Streaming Guide](streaming-programming-guide.html): an API preview of Spark Streaming * [Python Programming Guide](python-programming-guide.html): using Spark from Python
**API Docs:**
* [Spark Java/Scala (Scaladoc)](api/core/index.html)
* [Spark Python (Epydoc)](api/pyspark/index.html)
* [Spark Streaming Java/Scala (Scaladoc)](api/streaming/index.html)
**Deployment guides:** **Deployment guides:**
@ -73,7 +80,6 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`).
* [Configuration](configuration.html): customize Spark via its configuration system * [Configuration](configuration.html): customize Spark via its configuration system
* [Tuning Guide](tuning.html): best practices to optimize performance and memory use * [Tuning Guide](tuning.html): best practices to optimize performance and memory use
* [API Docs (Scaladoc)](api/core/index.html)
* [Bagel](bagel-programming-guide.html): an implementation of Google's Pregel on Spark * [Bagel](bagel-programming-guide.html): an implementation of Google's Pregel on Spark
* [Contributing to Spark](contributing-to-spark.html) * [Contributing to Spark](contributing-to-spark.html)

Просмотреть файл

@ -0,0 +1,111 @@
---
layout: global
title: Python Programming Guide
---
The Spark Python API (PySpark) exposes most of the Spark features available in the Scala version to Python.
To learn the basics of Spark, we recommend reading through the
[Scala programming guide](scala-programming-guide.html) first; it should be
easy to follow even if you don't know Scala.
This guide will show how to use the Spark features described there in Python.
# Key Differences in the Python API
There are a few key differences between the Python and Scala APIs:
* Python is dynamically typed, so RDDs can hold objects of different types.
* PySpark does not currently support the following Spark features:
- Accumulators
- Special functions on RDDs of doubles, such as `mean` and `stdev`
- `lookup`
- `persist` at storage levels other than `MEMORY_ONLY`
- `sample`
- `sort`
In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types.
Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax:
{% highlight python %}
logData = sc.textFile(logFile).cache()
errors = logData.filter(lambda s: 'ERROR' in s.split())
{% endhighlight %}
You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`:
{% highlight python %}
def is_error(line):
return 'ERROR' in line.split()
errors = logData.filter(is_error)
{% endhighlight %}
Functions can access objects in enclosing scopes, although modifications to those objects within RDD methods will not be propagated to other tasks:
{% highlight python %}
error_keywords = ["Exception", "Error"]
def is_error(line):
words = line.split()
return any(keyword in words for keyword in error_keywords)
errors = logData.filter(is_error)
{% endhighlight %}
PySpark will automatically ship these functions to workers, along with any objects that they reference.
Instances of classes will be serialized and shipped to workers by PySpark, but classes themselves cannot be automatically distributed to workers.
The [Standalone Use](#standalone-use) section describes how to ship code dependencies to workers.
# Installing and Configuring PySpark
PySpark requires Python 2.6 or higher.
PySpark jobs are executed using a standard cPython interpreter in order to support Python modules that use C extensions.
We have not tested PySpark with Python 3 or with alternative Python interpreters, such as [PyPy](http://pypy.org/) or [Jython](http://www.jython.org/).
By default, PySpark's scripts will run programs using `python`; an alternate Python executable may be specified by setting the `PYSPARK_PYTHON` environment variable in `conf/spark-env.sh`.
All of PySpark's library dependencies, including [Py4J](http://py4j.sourceforge.net/), are bundled with PySpark and automatically imported.
Standalone PySpark jobs should be run using the `pyspark` script, which automatically configures the Java and Python environment using the settings in `conf/spark-env.sh`.
The script automatically adds the `pyspark` package to the `PYTHONPATH`.
# Interactive Use
The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs.
When run without any input files, `pyspark` launches a shell that can be used explore data interactively, which is a simple way to learn the API:
{% highlight python %}
>>> words = sc.textFile("/usr/share/dict/words")
>>> words.filter(lambda w: w.startswith("spar")).take(5)
[u'spar', u'sparable', u'sparada', u'sparadrap', u'sparagrass']
{% endhighlight %}
By default, the `pyspark` shell creates SparkContext that runs jobs locally.
To connect to a non-local cluster, set the `MASTER` environment variable.
For example, to use the `pyspark` shell with a [standalone Spark cluster](spark-standalone.html):
{% highlight shell %}
$ MASTER=spark://IP:PORT ./pyspark
{% endhighlight %}
# Standalone Use
PySpark can also be used from standalone Python scripts by creating a SparkContext in your script and running the script using `pyspark`.
The Quick Start guide includes a [complete example](quick-start.html#a-standalone-job-in-python) of a standalone Python job.
Code dependencies can be deployed by listing them in the `pyFiles` option in the SparkContext constructor:
{% highlight python %}
from pyspark import SparkContext
sc = SparkContext("local", "Job Name", pyFiles=['MyFile.py', 'lib.zip', 'app.egg'])
{% endhighlight %}
Files listed here will be added to the `PYTHONPATH` and shipped to remote worker machines.
Code dependencies can be added to an existing SparkContext using its `addPyFile()` method.
# Where to Go from Here
PySpark includes several sample programs using the Python API in `python/examples`.
You can run them by passing the files to the `pyspark` script -- for example `./pyspark python/examples/wordcount.py`.
Each example program prints usage help when run without any arguments.
We currently provide [API documentation](api/pyspark/index.html) for the Python API as Epydoc.
Many of the RDD method descriptions contain [doctests](http://docs.python.org/2/library/doctest.html) that provide additional usage examples.

Просмотреть файл

@ -6,7 +6,8 @@ title: Quick Start
* This will become a table of contents (this text will be scraped). * This will become a table of contents (this text will be scraped).
{:toc} {:toc}
This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will not need much for this), then show how to write standalone jobs in Scala and Java. See the [programming guide](scala-programming-guide.html) for a more complete reference. This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will not need much for this), then show how to write standalone jobs in Scala, Java, and Python.
See the [programming guide](scala-programming-guide.html) for a more complete reference.
To follow along with this guide, you only need to have successfully built Spark on one machine. Simply go into your Spark directory and run: To follow along with this guide, you only need to have successfully built Spark on one machine. Simply go into your Spark directory and run:
@ -200,6 +201,16 @@ To build the job, we also write a Maven `pom.xml` file that lists Spark as a dep
<name>Simple Project</name> <name>Simple Project</name>
<packaging>jar</packaging> <packaging>jar</packaging>
<version>1.0</version> <version>1.0</version>
<repositories>
<repository>
<id>Spray.cc repository</id>
<url>http://repo.spray.cc</url>
</repository>
<repository>
<id>Typesafe repository</id>
<url>http://repo.typesafe.com/typesafe/releases</url>
</repository>
</repositories>
<dependencies> <dependencies>
<dependency> <!-- Spark dependency --> <dependency> <!-- Spark dependency -->
<groupId>org.spark-project</groupId> <groupId>org.spark-project</groupId>
@ -230,3 +241,40 @@ Lines with a: 8422, Lines with b: 1836
{% endhighlight %} {% endhighlight %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS. This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
# A Standalone Job In Python
Now we will show how to write a standalone job using the Python API (PySpark).
As an example, we'll create a simple Spark job, `SimpleJob.py`:
{% highlight python %}
"""SimpleJob.py"""
from pyspark import SparkContext
logFile = "/var/log/syslog" # Should be some file on your system
sc = SparkContext("local", "Simple job")
logData = sc.textFile(logFile).cache()
numAs = logData.filter(lambda s: 'a' in s).count()
numBs = logData.filter(lambda s: 'b' in s).count()
print "Lines with a: %i, lines with b: %i" % (numAs, numBs)
{% endhighlight %}
This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file.
Like in the Scala and Java examples, we use a SparkContext to create RDDs.
We can pass Python functions to Spark, which are automatically serialized along with any variables that they reference.
For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide).
`SimpleJob` is simple enough that we do not need to specify any code dependencies.
We can run this job using the `pyspark` script:
{% highlight python %}
$ cd $SPARK_HOME
$ ./pyspark SimpleJob.py
...
Lines with a: 8422, Lines with b: 1836
{% endhighlight python %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.

Просмотреть файл

@ -45,6 +45,11 @@
<profiles> <profiles>
<profile> <profile>
<id>hadoop1</id> <id>hadoop1</id>
<activation>
<property>
<name>!hadoopVersion</name>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.spark-project</groupId> <groupId>org.spark-project</groupId>
@ -72,6 +77,12 @@
</profile> </profile>
<profile> <profile>
<id>hadoop2</id> <id>hadoop2</id>
<activation>
<property>
<name>hadoopVersion</name>
<value>2</value>
</property>
</activation>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.spark-project</groupId> <groupId>org.spark-project</groupId>

Просмотреть файл

@ -5,7 +5,7 @@ import spark.util.Vector
object LocalLR { object LocalLR {
val N = 10000 // Number of data points val N = 10000 // Number of data points
val D = 10 // Numer of dimensions val D = 10 // Number of dimensions
val R = 0.7 // Scaling factor val R = 0.7 // Scaling factor
val ITERATIONS = 5 val ITERATIONS = 5
val rand = new Random(42) val rand = new Random(42)

20
pom.xml
Просмотреть файл

@ -54,6 +54,7 @@
<mesos.version>0.9.0-incubating</mesos.version> <mesos.version>0.9.0-incubating</mesos.version>
<akka.version>2.0.3</akka.version> <akka.version>2.0.3</akka.version>
<spray.version>1.0-M2.1</spray.version> <spray.version>1.0-M2.1</spray.version>
<spray.json.version>1.1.1</spray.json.version>
<slf4j.version>1.6.1</slf4j.version> <slf4j.version>1.6.1</slf4j.version>
<cdh.version>4.1.2</cdh.version> <cdh.version>4.1.2</cdh.version>
</properties> </properties>
@ -222,6 +223,11 @@
<artifactId>spray-server</artifactId> <artifactId>spray-server</artifactId>
<version>${spray.version}</version> <version>${spray.version}</version>
</dependency> </dependency>
<dependency>
<groupId>cc.spray</groupId>
<artifactId>spray-json_${scala.version}</artifactId>
<version>${spray.json.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.tomdz.twirl</groupId> <groupId>org.tomdz.twirl</groupId>
<artifactId>twirl-api</artifactId> <artifactId>twirl-api</artifactId>
@ -481,6 +487,12 @@
<profiles> <profiles>
<profile> <profile>
<id>hadoop1</id> <id>hadoop1</id>
<activation>
<property>
<name>!hadoopVersion</name>
</property>
</activation>
<properties> <properties>
<hadoop.major.version>1</hadoop.major.version> <hadoop.major.version>1</hadoop.major.version>
</properties> </properties>
@ -489,7 +501,7 @@
<dependency> <dependency>
<groupId>org.apache.hadoop</groupId> <groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId> <artifactId>hadoop-core</artifactId>
<version>0.20.205.0</version> <version>1.0.3</version>
</dependency> </dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>
@ -497,6 +509,12 @@
<profile> <profile>
<id>hadoop2</id> <id>hadoop2</id>
<activation>
<property>
<name>hadoopVersion</name>
<value>2</value>
</property>
</activation>
<properties> <properties>
<hadoop.major.version>2</hadoop.major.version> <hadoop.major.version>2</hadoop.major.version>
</properties> </properties>

Просмотреть файл

@ -10,7 +10,7 @@ import twirl.sbt.TwirlPlugin._
object SparkBuild extends Build { object SparkBuild extends Build {
// Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or // Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or
// "1.0.3" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop. // "1.0.3" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop.
val HADOOP_VERSION = "0.20.205.0" val HADOOP_VERSION = "1.0.3"
val HADOOP_MAJOR_VERSION = "1" val HADOOP_MAJOR_VERSION = "1"
// For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2" // For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2"
@ -40,6 +40,7 @@ object SparkBuild extends Build {
scalacOptions := Seq(/*"-deprecation",*/ "-unchecked", "-optimize"), // -deprecation is too noisy due to usage of old Hadoop API, enable it once that's no longer an issue scalacOptions := Seq(/*"-deprecation",*/ "-unchecked", "-optimize"), // -deprecation is too noisy due to usage of old Hadoop API, enable it once that's no longer an issue
unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath },
retrieveManaged := true, retrieveManaged := true,
retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
transitiveClassifiers in Scope.GlobalScope := Seq("sources"), transitiveClassifiers in Scope.GlobalScope := Seq("sources"),
testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))),
@ -139,6 +140,7 @@ object SparkBuild extends Build {
"org.twitter4j" % "twitter4j-stream" % "3.0.2", "org.twitter4j" % "twitter4j-stream" % "3.0.2",
"cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-can" % "1.0-M2.1",
"cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1",
"cc.spray" %% "spray-json" % "1.1.1",
"org.apache.mesos" % "mesos" % "0.9.0-incubating" "org.apache.mesos" % "mesos" % "0.9.0-incubating"
) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq,
unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") }

32
pyspark Executable file
Просмотреть файл

@ -0,0 +1,32 @@
#!/usr/bin/env bash
# Figure out where the Scala framework is installed
FWDIR="$(cd `dirname $0`; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
# Load environment variables from conf/spark-env.sh, if it exists
if [ -e $FWDIR/conf/spark-env.sh ] ; then
. $FWDIR/conf/spark-env.sh
fi
# Figure out which Python executable to use
if [ -z "$PYSPARK_PYTHON" ] ; then
PYSPARK_PYTHON="python"
fi
export PYSPARK_PYTHON
# Add the PySpark classes to the Python path:
export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH
# Load the PySpark shell.py script when ./pyspark is used interactively:
export OLD_PYTHONSTARTUP=$PYTHONSTARTUP
export PYTHONSTARTUP=$FWDIR/python/pyspark/shell.py
# Launch with `scala` by default:
if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then
export SPARK_LAUNCH_WITH_SCALA=1
fi
exec "$PYSPARK_PYTHON" "$@"

2
python/.gitignore поставляемый Normal file
Просмотреть файл

@ -0,0 +1,2 @@
*.pyc
docs/

19
python/epydoc.conf Normal file
Просмотреть файл

@ -0,0 +1,19 @@
[epydoc] # Epydoc section marker (required by ConfigParser)
# Information about the project.
name: PySpark
url: http://spark-project.org
# The list of modules to document. Modules can be named using
# dotted names, module filenames, or package directory names.
# This option may be repeated.
modules: pyspark
# Write html output to the directory "apidocs"
output: html
target: docs/
private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
pyspark.java_gateway pyspark.examples pyspark.shell

54
python/examples/kmeans.py Normal file
Просмотреть файл

@ -0,0 +1,54 @@
"""
This example requires numpy (http://www.numpy.org/)
"""
import sys
import numpy as np
from pyspark import SparkContext
def parseVector(line):
return np.array([float(x) for x in line.split(' ')])
def closestPoint(p, centers):
bestIndex = 0
closest = float("+inf")
for i in range(len(centers)):
tempDist = np.sum((p - centers[i]) ** 2)
if tempDist < closest:
closest = tempDist
bestIndex = i
return bestIndex
if __name__ == "__main__":
if len(sys.argv) < 5:
print >> sys.stderr, \
"Usage: PythonKMeans <master> <file> <k> <convergeDist>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonKMeans")
lines = sc.textFile(sys.argv[2])
data = lines.map(parseVector).cache()
K = int(sys.argv[3])
convergeDist = float(sys.argv[4])
# TODO: change this after we port takeSample()
#kPoints = data.takeSample(False, K, 34)
kPoints = data.take(K)
tempDist = 1.0
while tempDist > convergeDist:
closest = data.map(
lambda p : (closestPoint(p, kPoints), (p, 1)))
pointStats = closest.reduceByKey(
lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2))
newPoints = pointStats.map(
lambda (x, (y, z)): (x, y / z)).collect()
tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints)
for (x, y) in newPoints:
kPoints[x] = y
print "Final centers: " + str(kPoints)

Просмотреть файл

@ -0,0 +1,57 @@
"""
This example requires numpy (http://www.numpy.org/)
"""
from collections import namedtuple
from math import exp
from os.path import realpath
import sys
import numpy as np
from pyspark import SparkContext
N = 100000 # Number of data points
D = 10 # Number of dimensions
R = 0.7 # Scaling factor
ITERATIONS = 5
np.random.seed(42)
DataPoint = namedtuple("DataPoint", ['x', 'y'])
from lr import DataPoint # So that DataPoint is properly serialized
def generateData():
def generatePoint(i):
y = -1 if i % 2 == 0 else 1
x = np.random.normal(size=D) + (y * R)
return DataPoint(x, y)
return [generatePoint(i) for i in range(N)]
if __name__ == "__main__":
if len(sys.argv) == 1:
print >> sys.stderr, \
"Usage: PythonLR <master> [<slices>]"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)])
slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
points = sc.parallelize(generateData(), slices).cache()
# Initialize w to a random value
w = 2 * np.random.ranf(size=D) - 1
print "Initial w: " + str(w)
def add(x, y):
x += y
return x
for i in range(1, ITERATIONS + 1):
print "On iteration %i" % i
gradient = points.map(lambda p:
(1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x
).reduce(add)
w -= gradient
print "Final w: " + str(w)

21
python/examples/pi.py Normal file
Просмотреть файл

@ -0,0 +1,21 @@
import sys
from random import random
from operator import add
from pyspark import SparkContext
if __name__ == "__main__":
if len(sys.argv) == 1:
print >> sys.stderr, \
"Usage: PythonPi <master> [<slices>]"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonPi")
slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
n = 100000 * slices
def f(_):
x = random() * 2 - 1
y = random() * 2 - 1
return 1 if x ** 2 + y ** 2 < 1 else 0
count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add)
print "Pi is roughly %f" % (4.0 * count / n)

Просмотреть файл

@ -0,0 +1,50 @@
import sys
from random import Random
from pyspark import SparkContext
numEdges = 200
numVertices = 100
rand = Random(42)
def generateGraph():
edges = set()
while len(edges) < numEdges:
src = rand.randrange(0, numEdges)
dst = rand.randrange(0, numEdges)
if src != dst:
edges.add((src, dst))
return edges
if __name__ == "__main__":
if len(sys.argv) == 1:
print >> sys.stderr, \
"Usage: PythonTC <master> [<slices>]"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonTC")
slices = sys.argv[2] if len(sys.argv) > 2 else 2
tc = sc.parallelize(generateGraph(), slices).cache()
# Linear transitive closure: each round grows paths by one edge,
# by joining the graph's edges with the already-discovered paths.
# e.g. join the path (y, z) from the TC with the edge (x, y) from
# the graph to obtain the path (x, z).
# Because join() joins on keys, the edges are stored in reversed order.
edges = tc.map(lambda (x, y): (y, x))
oldCount = 0L
nextCount = tc.count()
while True:
oldCount = nextCount
# Perform the join, obtaining an RDD of (y, (z, x)) pairs,
# then project the result to obtain the new (x, z) paths.
new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a))
tc = tc.union(new_edges).distinct().cache()
nextCount = tc.count()
if nextCount == oldCount:
break
print "TC has %i edges" % tc.count()

Просмотреть файл

@ -0,0 +1,19 @@
import sys
from operator import add
from pyspark import SparkContext
if __name__ == "__main__":
if len(sys.argv) < 3:
print >> sys.stderr, \
"Usage: PythonWordCount <master> <file>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonWordCount")
lines = sc.textFile(sys.argv[2], 1)
counts = lines.flatMap(lambda x: x.split(' ')) \
.map(lambda x: (x, 1)) \
.reduceByKey(add)
output = counts.collect()
for (word, count) in output:
print "%s : %i" % (word, count)

Просмотреть файл

@ -0,0 +1,27 @@
Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
- Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
- The name of the author may not be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

Просмотреть файл

@ -0,0 +1 @@
b7924aabe9c5e63f0a4d8bbd17019534c7ec014e

Двоичные данные
python/lib/py4j0.7.egg Normal file

Двоичный файл не отображается.

Двоичные данные
python/lib/py4j0.7.jar Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,20 @@
"""
PySpark is a Python API for Spark.
Public classes:
- L{SparkContext<pyspark.context.SparkContext>}
Main entry point for Spark functionality.
- L{RDD<pyspark.rdd.RDD>}
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg"))
from pyspark.context import SparkContext
from pyspark.rdd import RDD
__all__ = ["SparkContext", "RDD"]

Просмотреть файл

@ -0,0 +1,48 @@
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> b = sc.broadcast([1, 2, 3, 4, 5])
>>> b.value
[1, 2, 3, 4, 5]
>>> from pyspark.broadcast import _broadcastRegistry
>>> _broadcastRegistry[b.bid] = b
>>> from cPickle import dumps, loads
>>> loads(dumps(b)).value
[1, 2, 3, 4, 5]
>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
>>> large_broadcast = sc.broadcast(list(range(10000)))
"""
# Holds broadcasted data received from Java, keyed by its id.
_broadcastRegistry = {}
def _from_id(bid):
from pyspark.broadcast import _broadcastRegistry
if bid not in _broadcastRegistry:
raise Exception("Broadcast variable '%s' not loaded!" % bid)
return _broadcastRegistry[bid]
class Broadcast(object):
def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
self.value = value
self.bid = bid
self._jbroadcast = java_broadcast
self._pickle_registry = pickle_registry
def __reduce__(self):
self._pickle_registry.add(self)
return (_from_id, (self.bid, ))
def _test():
import doctest
doctest.testmod()
if __name__ == "__main__":
_test()

Просмотреть файл

@ -0,0 +1,974 @@
"""
This class is defined to override standard pickle functionality
The goals of it follow:
-Serialize lambdas and nested functions to compiled byte code
-Deal with main module correctly
-Deal with other non-serializable objects
It does not include an unpickler, as standard python unpickling suffices.
This module was extracted from the `cloud` package, developed by `PiCloud, Inc.
<http://www.picloud.com>`_.
Copyright (c) 2012, Regents of the University of California.
Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the University of California, Berkeley nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import operator
import os
import pickle
import struct
import sys
import types
from functools import partial
import itertools
from copy_reg import _extension_registry, _inverted_registry, _extension_cache
import new
import dis
import traceback
#relevant opcodes
STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL'))
DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL'))
LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL'))
GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
EXTENDED_ARG = chr(dis.EXTENDED_ARG)
import logging
cloudLog = logging.getLogger("Cloud.Transport")
try:
import ctypes
except (MemoryError, ImportError):
logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True)
ctypes = None
PyObject_HEAD = None
else:
# for reading internal structures
PyObject_HEAD = [
('ob_refcnt', ctypes.c_size_t),
('ob_type', ctypes.c_void_p),
]
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
# These helper functions were copied from PiCloud's util module.
def islambda(func):
return getattr(func,'func_name') == '<lambda>'
def xrange_params(xrangeobj):
"""Returns a 3 element tuple describing the xrange start, step, and len
respectively
Note: Only guarentees that elements of xrange are the same. parameters may
be different.
e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same
though w/ iteration
"""
xrange_len = len(xrangeobj)
if not xrange_len: #empty
return (0,1,0)
start = xrangeobj[0]
if xrange_len == 1: #one element
return start, 1, 1
return (start, xrangeobj[1] - xrangeobj[0], xrange_len)
#debug variables intended for developer use:
printSerialization = False
printMemoization = False
useForcedImports = True #Should I use forced imports for tracking?
class CloudPickler(pickle.Pickler):
dispatch = pickle.Pickler.dispatch.copy()
savedForceImports = False
savedDjangoEnv = False #hack tro transport django environment
def __init__(self, file, protocol=None, min_size_to_save= 0):
pickle.Pickler.__init__(self,file,protocol)
self.modules = set() #set of modules needed to depickle
self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env
def dump(self, obj):
# note: not thread safe
# minimal side-effects, so not fixing
recurse_limit = 3000
base_recurse = sys.getrecursionlimit()
if base_recurse < recurse_limit:
sys.setrecursionlimit(recurse_limit)
self.inject_addons()
try:
return pickle.Pickler.dump(self, obj)
except RuntimeError, e:
if 'recursion' in e.args[0]:
msg = """Could not pickle object as excessively deep recursion required.
Try _fast_serialization=2 or contact PiCloud support"""
raise pickle.PicklingError(msg)
finally:
new_recurse = sys.getrecursionlimit()
if new_recurse == recurse_limit:
sys.setrecursionlimit(base_recurse)
def save_buffer(self, obj):
"""Fallback to save_string"""
pickle.Pickler.save_string(self,str(obj))
dispatch[buffer] = save_buffer
#block broken objects
def save_unsupported(self, obj, pack=None):
raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj))
dispatch[types.GeneratorType] = save_unsupported
#python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it
try:
slice(0,1).__reduce__()
except TypeError: #can't pickle -
dispatch[slice] = save_unsupported
#itertools objects do not pickle!
for v in itertools.__dict__.values():
if type(v) is type:
dispatch[v] = save_unsupported
def save_dict(self, obj):
"""hack fix
If the dict is a global, deal with it in a special way
"""
#print 'saving', obj
if obj is __builtins__:
self.save_reduce(_get_module_builtins, (), obj=obj)
else:
pickle.Pickler.save_dict(self, obj)
dispatch[pickle.DictionaryType] = save_dict
def save_module(self, obj, pack=struct.pack):
"""
Save a module as an import
"""
#print 'try save import', obj.__name__
self.modules.add(obj)
self.save_reduce(subimport,(obj.__name__,), obj=obj)
dispatch[types.ModuleType] = save_module #new type
def save_codeobject(self, obj, pack=struct.pack):
"""
Save a code object
"""
#print 'try to save codeobj: ', obj
args = (
obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code,
obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name,
obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars
)
self.save_reduce(types.CodeType, args, obj=obj)
dispatch[types.CodeType] = save_codeobject #new type
def save_function(self, obj, name=None, pack=struct.pack):
""" Registered with the dispatch to handle all function types.
Determines what kind of function obj is (e.g. lambda, defined at
interactive prompt, etc) and handles the pickling appropriately.
"""
write = self.write
name = obj.__name__
modname = pickle.whichmodule(obj, name)
#print 'which gives %s %s %s' % (modname, obj, name)
try:
themodule = sys.modules[modname]
except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__
modname = '__main__'
if modname == '__main__':
themodule = None
if themodule:
self.modules.add(themodule)
if not self.savedDjangoEnv:
#hack for django - if we detect the settings module, we transport it
django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '')
if django_settings:
django_mod = sys.modules.get(django_settings)
if django_mod:
cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name)
self.savedDjangoEnv = True
self.modules.add(django_mod)
write(pickle.MARK)
self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod)
write(pickle.POP_MARK)
# if func is lambda, def'ed at prompt, is in main, or is nested, then
# we'll pickle the actual function object rather than simply saving a
# reference (as is done in default pickler), via save_function_tuple.
if islambda(obj) or obj.func_code.co_filename == '<stdin>' or themodule == None:
#Force server to import modules that have been imported in main
modList = None
if themodule == None and not self.savedForceImports:
mainmod = sys.modules['__main__']
if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'):
modList = list(mainmod.___pyc_forcedImports__)
self.savedForceImports = True
self.save_function_tuple(obj, modList)
return
else: # func is nested
klass = getattr(themodule, name, None)
if klass is None or klass is not obj:
self.save_function_tuple(obj, [themodule])
return
if obj.__dict__:
# essentially save_reduce, but workaround needed to avoid recursion
self.save(_restore_attr)
write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n')
self.memoize(obj)
self.save(obj.__dict__)
write(pickle.TUPLE + pickle.REDUCE)
else:
write(pickle.GLOBAL + modname + '\n' + name + '\n')
self.memoize(obj)
dispatch[types.FunctionType] = save_function
def save_function_tuple(self, func, forced_imports):
""" Pickles an actual func object.
A func comprises: code, globals, defaults, closure, and dict. We
extract and save these, injecting reducing functions at certain points
to recreate the func object. Keep in mind that some of these pieces
can contain a ref to the func itself. Thus, a naive save on these
pieces could trigger an infinite loop of save's. To get around that,
we first create a skeleton func object using just the code (this is
safe, since this won't contain a ref to the func), and memoize it as
soon as it's created. The other stuff can then be filled in later.
"""
save = self.save
write = self.write
# save the modules (if any)
if forced_imports:
write(pickle.MARK)
save(_modules_to_main)
#print 'forced imports are', forced_imports
forced_names = map(lambda m: m.__name__, forced_imports)
save((forced_names,))
#save((forced_imports,))
write(pickle.REDUCE)
write(pickle.POP_MARK)
code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
save(_fill_function) # skeleton function updater
write(pickle.MARK) # beginning of tuple that _fill_function expects
# create a skeleton function object and memoize it
save(_make_skel_func)
save((code, len(closure), base_globals))
write(pickle.REDUCE)
self.memoize(func)
# save the rest of the func data needed by _fill_function
save(f_globals)
save(defaults)
save(closure)
save(dct)
write(pickle.TUPLE)
write(pickle.REDUCE) # applies _fill_function on the tuple
@staticmethod
def extract_code_globals(co):
"""
Find all globals names read or written to by codeblock co
"""
code = co.co_code
names = co.co_names
out_names = set()
n = len(code)
i = 0
extended_arg = 0
while i < n:
op = code[i]
i = i+1
if op >= HAVE_ARGUMENT:
oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
extended_arg = 0
i = i+2
if op == EXTENDED_ARG:
extended_arg = oparg*65536L
if op in GLOBAL_OPS:
out_names.add(names[oparg])
#print 'extracted', out_names, ' from ', names
return out_names
def extract_func_data(self, func):
"""
Turn the function into a tuple of data necessary to recreate it:
code, globals, defaults, closure, dict
"""
code = func.func_code
# extract all global ref's
func_global_refs = CloudPickler.extract_code_globals(code)
if code.co_consts: # see if nested function have any global refs
for const in code.co_consts:
if type(const) is types.CodeType and const.co_names:
func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const))
# process all variables referenced by global environment
f_globals = {}
for var in func_global_refs:
#Some names, such as class functions are not global - we don't need them
if func.func_globals.has_key(var):
f_globals[var] = func.func_globals[var]
# defaults requires no processing
defaults = func.func_defaults
def get_contents(cell):
try:
return cell.cell_contents
except ValueError, e: #cell is empty error on not yet assigned
raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope')
# process closure
if func.func_closure:
closure = map(get_contents, func.func_closure)
else:
closure = []
# save the dict
dct = func.func_dict
if printSerialization:
outvars = ['code: ' + str(code) ]
outvars.append('globals: ' + str(f_globals))
outvars.append('defaults: ' + str(defaults))
outvars.append('closure: ' + str(closure))
print 'function ', func, 'is extracted to: ', ', '.join(outvars)
base_globals = self.globals_ref.get(id(func.func_globals), {})
self.globals_ref[id(func.func_globals)] = base_globals
return (code, f_globals, defaults, closure, dct, base_globals)
def save_global(self, obj, name=None, pack=struct.pack):
write = self.write
memo = self.memo
if name is None:
name = obj.__name__
modname = getattr(obj, "__module__", None)
if modname is None:
modname = pickle.whichmodule(obj, name)
try:
__import__(modname)
themodule = sys.modules[modname]
except (ImportError, KeyError, AttributeError): #should never occur
raise pickle.PicklingError(
"Can't pickle %r: Module %s cannot be found" %
(obj, modname))
if modname == '__main__':
themodule = None
if themodule:
self.modules.add(themodule)
sendRef = True
typ = type(obj)
#print 'saving', obj, typ
try:
try: #Deal with case when getattribute fails with exceptions
klass = getattr(themodule, name)
except (AttributeError):
if modname == '__builtin__': #new.* are misrepeported
modname = 'new'
__import__(modname)
themodule = sys.modules[modname]
try:
klass = getattr(themodule, name)
except AttributeError, a:
#print themodule, name, obj, type(obj)
raise pickle.PicklingError("Can't pickle builtin %s" % obj)
else:
raise
except (ImportError, KeyError, AttributeError):
if typ == types.TypeType or typ == types.ClassType:
sendRef = False
else: #we can't deal with this
raise
else:
if klass is not obj and (typ == types.TypeType or typ == types.ClassType):
sendRef = False
if not sendRef:
#note: Third party types might crash this - add better checks!
d = dict(obj.__dict__) #copy dict proxy to a dict
if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties
d.pop('__dict__',None)
d.pop('__weakref__',None)
# hack as __new__ is stored differently in the __dict__
new_override = d.get('__new__', None)
if new_override:
d['__new__'] = obj.__new__
self.save_reduce(type(obj),(obj.__name__,obj.__bases__,
d),obj=obj)
#print 'internal reduce dask %s %s' % (obj, d)
return
if self.proto >= 2:
code = _extension_registry.get((modname, name))
if code:
assert code > 0
if code <= 0xff:
write(pickle.EXT1 + chr(code))
elif code <= 0xffff:
write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8))
else:
write(pickle.EXT4 + pack("<i", code))
return
write(pickle.GLOBAL + modname + '\n' + name + '\n')
self.memoize(obj)
dispatch[types.ClassType] = save_global
dispatch[types.BuiltinFunctionType] = save_global
dispatch[types.TypeType] = save_global
def save_instancemethod(self, obj):
#Memoization rarely is ever useful due to python bounding
self.save_reduce(types.MethodType, (obj.im_func, obj.im_self,obj.im_class), obj=obj)
dispatch[types.MethodType] = save_instancemethod
def save_inst_logic(self, obj):
"""Inner logic to save instance. Based off pickle.save_inst
Supports __transient__"""
cls = obj.__class__
memo = self.memo
write = self.write
save = self.save
if hasattr(obj, '__getinitargs__'):
args = obj.__getinitargs__()
len(args) # XXX Assert it's a sequence
pickle._keep_alive(args, memo)
else:
args = ()
write(pickle.MARK)
if self.bin:
save(cls)
for arg in args:
save(arg)
write(pickle.OBJ)
else:
for arg in args:
save(arg)
write(pickle.INST + cls.__module__ + '\n' + cls.__name__ + '\n')
self.memoize(obj)
try:
getstate = obj.__getstate__
except AttributeError:
stuff = obj.__dict__
#remove items if transient
if hasattr(obj, '__transient__'):
transient = obj.__transient__
stuff = stuff.copy()
for k in list(stuff.keys()):
if k in transient:
del stuff[k]
else:
stuff = getstate()
pickle._keep_alive(stuff, memo)
save(stuff)
write(pickle.BUILD)
def save_inst(self, obj):
# Hack to detect PIL Image instances without importing Imaging
# PIL can be loaded with multiple names, so we don't check sys.modules for it
if hasattr(obj,'im') and hasattr(obj,'palette') and 'Image' in obj.__module__:
self.save_image(obj)
else:
self.save_inst_logic(obj)
dispatch[types.InstanceType] = save_inst
def save_property(self, obj):
# properties not correctly saved in python
self.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__), obj=obj)
dispatch[property] = save_property
def save_itemgetter(self, obj):
"""itemgetter serializer (needed for namedtuple support)
a bit of a pain as we need to read ctypes internals"""
class ItemGetterType(ctypes.Structure):
_fields_ = PyObject_HEAD + [
('nitems', ctypes.c_size_t),
('item', ctypes.py_object)
]
itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents
return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,))
if PyObject_HEAD:
dispatch[operator.itemgetter] = save_itemgetter
def save_reduce(self, func, args, state=None,
listitems=None, dictitems=None, obj=None):
"""Modified to support __transient__ on new objects
Change only affects protocol level 2 (which is always used by PiCloud"""
# Assert that args is a tuple or None
if not isinstance(args, types.TupleType):
raise pickle.PicklingError("args from reduce() should be a tuple")
# Assert that func is callable
if not hasattr(func, '__call__'):
raise pickle.PicklingError("func from reduce should be callable")
save = self.save
write = self.write
# Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
#Added fix to allow transient
cls = args[0]
if not hasattr(cls, "__new__"):
raise pickle.PicklingError(
"args[0] from __newobj__ args has no __new__")
if obj is not None and cls is not obj.__class__:
raise pickle.PicklingError(
"args[0] from __newobj__ args has the wrong class")
args = args[1:]
save(cls)
#Don't pickle transient entries
if hasattr(obj, '__transient__'):
transient = obj.__transient__
state = state.copy()
for k in list(state.keys()):
if k in transient:
del state[k]
save(args)
write(pickle.NEWOBJ)
else:
save(func)
save(args)
write(pickle.REDUCE)
if obj is not None:
self.memoize(obj)
# More new special cases (that work with older protocols as
# well): when __reduce__ returns a tuple with 4 or 5 items,
# the 4th and 5th item should be iterators that provide list
# items and dict items (as (key, value) tuples), or None.
if listitems is not None:
self._batch_appends(listitems)
if dictitems is not None:
self._batch_setitems(dictitems)
if state is not None:
#print 'obj %s has state %s' % (obj, state)
save(state)
write(pickle.BUILD)
def save_xrange(self, obj):
"""Save an xrange object in python 2.5
Python 2.6 supports this natively
"""
range_params = xrange_params(obj)
self.save_reduce(_build_xrange,range_params)
#python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it
try:
xrange(0).__reduce__()
except TypeError: #can't pickle -- use PiCloud pickler
dispatch[xrange] = save_xrange
def save_partial(self, obj):
"""Partial objects do not serialize correctly in python2.x -- this fixes the bugs"""
self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords))
if sys.version_info < (2,7): #2.7 supports partial pickling
dispatch[partial] = save_partial
def save_file(self, obj):
"""Save a file"""
import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
from ..transport.adapter import SerializingAdapter
if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
raise pickle.PicklingError("Cannot pickle files that do not map to an actual file")
if obj.name == '<stdout>':
return self.save_reduce(getattr, (sys,'stdout'), obj=obj)
if obj.name == '<stderr>':
return self.save_reduce(getattr, (sys,'stderr'), obj=obj)
if obj.name == '<stdin>':
raise pickle.PicklingError("Cannot pickle standard input")
if hasattr(obj, 'isatty') and obj.isatty():
raise pickle.PicklingError("Cannot pickle files that map to tty objects")
if 'r' not in obj.mode:
raise pickle.PicklingError("Cannot pickle files that are not opened for reading")
name = obj.name
try:
fsize = os.stat(name).st_size
except OSError:
raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name)
if obj.closed:
#create an empty closed string io
retval = pystringIO.StringIO("")
retval.close()
elif not fsize: #empty file
retval = pystringIO.StringIO("")
try:
tmpfile = file(name)
tst = tmpfile.read(1)
except IOError:
raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
tmpfile.close()
if tst != '':
raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name)
elif fsize > SerializingAdapter.max_transmit_data:
raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" %
(name,SerializingAdapter.max_transmit_data))
else:
try:
tmpfile = file(name)
contents = tmpfile.read(SerializingAdapter.max_transmit_data)
tmpfile.close()
except IOError:
raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
retval = pystringIO.StringIO(contents)
curloc = obj.tell()
retval.seek(curloc)
retval.name = name
self.save(retval) #save stringIO
self.memoize(obj)
dispatch[file] = save_file
"""Special functions for Add-on libraries"""
def inject_numpy(self):
numpy = sys.modules.get('numpy')
if not numpy or not hasattr(numpy, 'ufunc'):
return
self.dispatch[numpy.ufunc] = self.__class__.save_ufunc
numpy_tst_mods = ['numpy', 'scipy.special']
def save_ufunc(self, obj):
"""Hack function for saving numpy ufunc objects"""
name = obj.__name__
for tst_mod_name in self.numpy_tst_mods:
tst_mod = sys.modules.get(tst_mod_name, None)
if tst_mod:
if name in tst_mod.__dict__:
self.save_reduce(_getobject, (tst_mod_name, name))
return
raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj))
def inject_timeseries(self):
"""Handle bugs with pickling scikits timeseries"""
tseries = sys.modules.get('scikits.timeseries.tseries')
if not tseries or not hasattr(tseries, 'Timeseries'):
return
self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries
def save_timeseries(self, obj):
import scikits.timeseries.tseries as ts
func, reduce_args, state = obj.__reduce__()
if func != ts._tsreconstruct:
raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func))
state = (1,
obj.shape,
obj.dtype,
obj.flags.fnc,
obj._data.tostring(),
ts.getmaskarray(obj).tostring(),
obj._fill_value,
obj._dates.shape,
obj._dates.__array__().tostring(),
obj._dates.dtype, #added -- preserve type
obj.freq,
obj._optinfo,
)
return self.save_reduce(_genTimeSeries, (reduce_args, state))
def inject_email(self):
"""Block email LazyImporters from being saved"""
email = sys.modules.get('email')
if not email:
return
self.dispatch[email.LazyImporter] = self.__class__.save_unsupported
def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
self.inject_numpy()
self.inject_timeseries()
self.inject_email()
"""Python Imaging Library"""
def save_image(self, obj):
if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \
and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()):
#if image not loaded yet -- lazy load
self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj)
else:
#image is loaded - just transmit it over
self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj)
"""
def memoize(self, obj):
pickle.Pickler.memoize(self, obj)
if printMemoization:
print 'memoizing ' + str(obj)
"""
# Shorthands for legacy support
def dump(obj, file, protocol=2):
CloudPickler(file, protocol).dump(obj)
def dumps(obj, protocol=2):
file = StringIO()
cp = CloudPickler(file,protocol)
cp.dump(obj)
#print 'cloud dumped', str(obj), str(cp.modules)
return file.getvalue()
#hack for __import__ not working as desired
def subimport(name):
__import__(name)
return sys.modules[name]
#hack to load django settings:
def django_settings_load(name):
modified_env = False
if 'DJANGO_SETTINGS_MODULE' not in os.environ:
os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps
modified_env = True
try:
module = subimport(name)
except Exception, i:
print >> sys.stderr, 'Cloud not import django settings %s:' % (name)
print_exec(sys.stderr)
if modified_env:
del os.environ['DJANGO_SETTINGS_MODULE']
else:
#add project directory to sys,path:
if hasattr(module,'__file__'):
dirname = os.path.split(module.__file__)[0] + '/'
sys.path.append(dirname)
# restores function attributes
def _restore_attr(obj, attr):
for key, val in attr.items():
setattr(obj, key, val)
return obj
def _get_module_builtins():
return pickle.__builtins__
def print_exec(stream):
ei = sys.exc_info()
traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
def _modules_to_main(modList):
"""Force every module in modList to be placed into main"""
if not modList:
return
main = sys.modules['__main__']
for modname in modList:
if type(modname) is str:
try:
mod = __import__(modname)
except Exception, i: #catch all...
sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \
A version mismatch is likely. Specific error was:\n' % modname)
print_exec(sys.stderr)
else:
setattr(main,mod.__name__, mod)
else:
#REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD)
#In old version actual module was sent
setattr(main,modname.__name__, modname)
#object generators:
def _build_xrange(start, step, len):
"""Built xrange explicitly"""
return xrange(start, start + step*len, step)
def _genpartial(func, args, kwds):
if not args:
args = ()
if not kwds:
kwds = {}
return partial(func, *args, **kwds)
def _fill_function(func, globals, defaults, closure, dict):
""" Fills in the rest of function data into the skeleton function object
that were created via _make_skel_func().
"""
func.func_globals.update(globals)
func.func_defaults = defaults
func.func_dict = dict
if len(closure) != len(func.func_closure):
raise pickle.UnpicklingError("closure lengths don't match up")
for i in range(len(closure)):
_change_cell_value(func.func_closure[i], closure[i])
return func
def _make_skel_func(code, num_closures, base_globals = None):
""" Creates a skeleton function object that contains just the provided
code and the correct number of cells in func_closure. All other
func attributes (e.g. func_globals) are empty.
"""
#build closure (cells):
if not ctypes:
raise Exception('ctypes failed to import; cannot build function')
cellnew = ctypes.pythonapi.PyCell_New
cellnew.restype = ctypes.py_object
cellnew.argtypes = (ctypes.py_object,)
dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures)))
if base_globals is None:
base_globals = {}
base_globals['__builtins__'] = __builtins__
return types.FunctionType(code, base_globals,
None, None, dummy_closure)
# this piece of opaque code is needed below to modify 'cell' contents
cell_changer_code = new.code(
1, 1, 2, 0,
''.join([
chr(dis.opmap['LOAD_FAST']), '\x00\x00',
chr(dis.opmap['DUP_TOP']),
chr(dis.opmap['STORE_DEREF']), '\x00\x00',
chr(dis.opmap['RETURN_VALUE'])
]),
(), (), ('newval',), '<nowhere>', 'cell_changer', 1, '', ('c',), ()
)
def _change_cell_value(cell, newval):
""" Changes the contents of 'cell' object to newval """
return new.function(cell_changer_code, {}, None, (), (cell,))(newval)
"""Constructors for 3rd party libraries
Note: These can never be renamed due to client compatibility issues"""
def _getobject(modname, attribute):
mod = __import__(modname)
return mod.__dict__[attribute]
def _generateImage(size, mode, str_rep):
"""Generate image from string representation"""
import Image
i = Image.new(mode, size)
i.fromstring(str_rep)
return i
def _lazyloadImage(fp):
import Image
fp.seek(0) #works in almost any case
return Image.open(fp)
"""Timeseries"""
def _genTimeSeries(reduce_args, state):
import scikits.timeseries.tseries as ts
from numpy import ndarray
from numpy.ma import MaskedArray
time_series = ts._tsreconstruct(*reduce_args)
#from setstate modified
(ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state
#print 'regenerating %s' % dtyp
MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv))
_dates = time_series._dates
#_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ
ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm))
_dates.freq = frq
_dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None,
toobj=None, toord=None, tostr=None))
# Update the _optinfo dictionary
time_series._optinfo.update(infodict)
return time_series

159
python/pyspark/context.py Normal file
Просмотреть файл

@ -0,0 +1,159 @@
import os
import atexit
from tempfile import NamedTemporaryFile
from pyspark.broadcast import Broadcast
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.rdd import RDD
from py4j.java_collections import ListConverter
class SparkContext(object):
"""
Main entry point for Spark functionality. A SparkContext represents the
connection to a Spark cluster, and can be used to create L{RDD}s and
broadcast variables on that cluster.
"""
gateway = launch_gateway()
jvm = gateway.jvm
_readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
_writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
_takePartition = jvm.PythonRDD.takePartition
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
"""
Create a new SparkContext.
@param master: Cluster URL to connect to
(e.g. mesos://host:port, spark://host:port, local[4]).
@param jobName: A name for your job, to display on the cluster web UI
@param sparkHome: Location where Spark is installed on cluster nodes.
@param pyFiles: Collection of .zip or .py files to send to the cluster
and add to PYTHONPATH. These can be paths on the local file
system or HDFS, HTTP, HTTPS, or FTP URLs.
@param environment: A dictionary of environment variables to set on
worker nodes.
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
"""
self.master = master
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
self.environment = environment or {}
self.batchSize = batchSize # -1 represents a unlimited batch size
# Create the Java SparkContext through Py4J
empty_string_array = self.gateway.new_array(self.jvm.String, 0)
self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
empty_string_array)
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
# been pickled, so it can determine which Java broadcast objects to
# send.
self._pickled_broadcast_vars = set()
# Deploy any code dependencies specified in the constructor
for path in (pyFiles or []):
self.addPyFile(path)
@property
def defaultParallelism(self):
"""
Default level of parallelism to use when not given by user (e.g. for
reduce tasks)
"""
return self._jsc.sc().defaultParallelism()
def __del__(self):
if self._jsc:
self._jsc.stop()
def stop(self):
"""
Shut down the SparkContext.
"""
self._jsc.stop()
self._jsc = None
def parallelize(self, c, numSlices=None):
"""
Distribute a local Python collection to form an RDD.
"""
numSlices = numSlices or self.defaultParallelism
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False)
atexit.register(lambda: os.unlink(tempFile.name))
if self.batchSize != 1:
c = batched(c, self.batchSize)
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()
jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self)
def textFile(self, name, minSplits=None):
"""
Read a text file from HDFS, a local file system (available on all
nodes), or any Hadoop-supported file system URI, and return it as an
RDD of Strings.
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
def union(self, rdds):
"""
Build the union of a list of RDDs.
"""
first = rdds[0]._jrdd
rest = [x._jrdd for x in rdds[1:]]
rest = ListConverter().convert(rest, self.gateway._gateway_client)
return RDD(self._jsc.union(first, rest), self)
def broadcast(self, value):
"""
Broadcast a read-only variable to the cluster, returning a C{Broadcast}
object for reading it in distributed functions. The variable will be
sent to each cluster only once.
"""
jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
def addFile(self, path):
"""
Add a file to be downloaded into the working directory of this Spark
job on every node. The C{path} passed can be either a local file,
a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
HTTPS or FTP URI.
"""
self._jsc.sc().addFile(path)
def clearFiles(self):
"""
Clear the job's list of files added by L{addFile} or L{addPyFile} so
that they do not get downloaded to any new nodes.
"""
# TODO: remove added .py or .zip files from the PYTHONPATH?
self._jsc.sc().clearFiles()
def addPyFile(self, path):
"""
Add a .py or .zip dependency for all tasks to be executed on this
SparkContext in the future. The C{path} passed can be either a local
file, a file in HDFS (or other Hadoop-supported filesystems), or an
HTTP, HTTPS or FTP URI.
"""
self.addFile(path)
filename = path.split("/")[-1]
os.environ["PYTHONPATH"] = \
"%s:%s" % (filename, os.environ["PYTHONPATH"])

Просмотреть файл

@ -0,0 +1,38 @@
import os
import sys
from subprocess import Popen, PIPE
from threading import Thread
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
SPARK_HOME = os.environ["SPARK_HOME"]
def launch_gateway():
# Launch the Py4j gateway using Spark's run command so that we pick up the
# proper classpath and SPARK_MEM settings from spark-env.sh
command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer",
"--die-on-broken-pipe", "0"]
proc = Popen(command, stdout=PIPE, stdin=PIPE)
# Determine which ephemeral port the server started on:
port = int(proc.stdout.readline())
# Create a thread to echo output from the GatewayServer, which is required
# for Java log output to show up:
class EchoOutputThread(Thread):
def __init__(self, stream):
Thread.__init__(self)
self.daemon = True
self.stream = stream
def run(self):
while True:
line = self.stream.readline()
sys.stderr.write(line)
EchoOutputThread(proc.stdout).start()
# Connect to the gateway
gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
# Import the classes used by PySpark
java_import(gateway.jvm, "spark.api.java.*")
java_import(gateway.jvm, "spark.api.python.*")
java_import(gateway.jvm, "scala.Tuple2")
return gateway

92
python/pyspark/join.py Normal file
Просмотреть файл

@ -0,0 +1,92 @@
"""
Copyright (c) 2011, Douban Inc. <http://www.douban.com/>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of the Douban Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
def _do_python_join(rdd, other, numSplits, dispatch):
vs = rdd.map(lambda (k, v): (k, (1, v)))
ws = other.map(lambda (k, v): (k, (2, v)))
return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch)
def python_join(rdd, other, numSplits):
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
if n == 1:
vbuf.append(v)
elif n == 2:
wbuf.append(v)
return [(v, w) for v in vbuf for w in wbuf]
return _do_python_join(rdd, other, numSplits, dispatch)
def python_right_outer_join(rdd, other, numSplits):
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
if n == 1:
vbuf.append(v)
elif n == 2:
wbuf.append(v)
if not vbuf:
vbuf.append(None)
return [(v, w) for v in vbuf for w in wbuf]
return _do_python_join(rdd, other, numSplits, dispatch)
def python_left_outer_join(rdd, other, numSplits):
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
if n == 1:
vbuf.append(v)
elif n == 2:
wbuf.append(v)
if not wbuf:
wbuf.append(None)
return [(v, w) for v in vbuf for w in wbuf]
return _do_python_join(rdd, other, numSplits, dispatch)
def python_cogroup(rdd, other, numSplits):
vs = rdd.map(lambda (k, v): (k, (1, v)))
ws = other.map(lambda (k, v): (k, (2, v)))
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
if n == 1:
vbuf.append(v)
elif n == 2:
wbuf.append(v)
return (vbuf, wbuf)
return vs.union(ws).groupByKey(numSplits).mapValues(dispatch)

723
python/pyspark/rdd.py Normal file
Просмотреть файл

@ -0,0 +1,723 @@
import atexit
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from itertools import chain, ifilter, imap, product
import operator
import os
import shlex
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
from pyspark import cloudpickle
from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
read_from_pickle_file
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from py4j.java_collections import ListConverter, MapConverter
__all__ = ["RDD"]
class RDD(object):
"""
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
Represents an immutable, partitioned collection of elements that can be
operated on in parallel.
"""
def __init__(self, jrdd, ctx):
self._jrdd = jrdd
self.is_cached = False
self.ctx = ctx
@property
def context(self):
"""
The L{SparkContext} that this RDD was created on.
"""
return self.ctx
def cache(self):
"""
Persist this RDD with the default storage level (C{MEMORY_ONLY}).
"""
self.is_cached = True
self._jrdd.cache()
return self
# TODO persist(self, storageLevel)
def map(self, f, preservesPartitioning=False):
"""
Return a new RDD containing the distinct elements in this RDD.
"""
def func(split, iterator): return imap(f, iterator)
return PipelinedRDD(self, func, preservesPartitioning)
def flatMap(self, f, preservesPartitioning=False):
"""
Return a new RDD by first applying a function to all elements of this
RDD, and then flattening the results.
>>> rdd = sc.parallelize([2, 3, 4])
>>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
[1, 1, 1, 2, 2, 3]
>>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
def func(s, iterator): return chain.from_iterable(imap(f, iterator))
return self.mapPartitionsWithSplit(func, preservesPartitioning)
def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition of this RDD.
>>> rdd = sc.parallelize([1, 2, 3, 4], 2)
>>> def f(iterator): yield sum(iterator)
>>> rdd.mapPartitions(f).collect()
[3, 7]
"""
def func(s, iterator): return f(iterator)
return self.mapPartitionsWithSplit(func)
def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition of this RDD,
while tracking the index of the original partition.
>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(splitIndex, iterator): yield splitIndex
>>> rdd.mapPartitionsWithSplit(f).sum()
6
"""
return PipelinedRDD(self, f, preservesPartitioning)
def filter(self, f):
"""
Return a new RDD containing only the elements that satisfy a predicate.
>>> rdd = sc.parallelize([1, 2, 3, 4, 5])
>>> rdd.filter(lambda x: x % 2 == 0).collect()
[2, 4]
"""
def func(iterator): return ifilter(f, iterator)
return self.mapPartitions(func)
def distinct(self):
"""
Return a new RDD containing the distinct elements in this RDD.
>>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
[1, 2, 3]
"""
return self.map(lambda x: (x, "")) \
.reduceByKey(lambda x, _: x) \
.map(lambda (x, _): x)
# TODO: sampling needs to be re-implemented due to Batch
#def sample(self, withReplacement, fraction, seed):
# jrdd = self._jrdd.sample(withReplacement, fraction, seed)
# return RDD(jrdd, self.ctx)
#def takeSample(self, withReplacement, num, seed):
# vals = self._jrdd.takeSample(withReplacement, num, seed)
# return [load_pickle(bytes(x)) for x in vals]
def union(self, other):
"""
Return the union of this RDD and another one.
>>> rdd = sc.parallelize([1, 1, 2, 3])
>>> rdd.union(rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
"""
return RDD(self._jrdd.union(other._jrdd), self.ctx)
def __add__(self, other):
"""
Return the union of this RDD and another one.
>>> rdd = sc.parallelize([1, 1, 2, 3])
>>> (rdd + rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
"""
if not isinstance(other, RDD):
raise TypeError
return self.union(other)
# TODO: sort
def glom(self):
"""
Return an RDD created by coalescing all elements within each partition
into a list.
>>> rdd = sc.parallelize([1, 2, 3, 4], 2)
>>> sorted(rdd.glom().collect())
[[1, 2], [3, 4]]
"""
def func(iterator): yield list(iterator)
return self.mapPartitions(func)
def cartesian(self, other):
"""
Return the Cartesian product of this RDD and another one, that is, the
RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and
C{b} is in C{other}.
>>> rdd = sc.parallelize([1, 2])
>>> sorted(rdd.cartesian(rdd).collect())
[(1, 1), (1, 2), (2, 1), (2, 2)]
"""
# Due to batching, we can't use the Java cartesian method.
java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
def unpack_batches(pair):
(x, y) = pair
if type(x) == Batch or type(y) == Batch:
xs = x.items if type(x) == Batch else [x]
ys = y.items if type(y) == Batch else [y]
for pair in product(xs, ys):
yield pair
else:
yield pair
return java_cartesian.flatMap(unpack_batches)
def groupBy(self, f, numSplits=None):
"""
Return an RDD of grouped items.
>>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
>>> result = rdd.groupBy(lambda x: x % 2).collect()
>>> sorted([(x, sorted(y)) for (x, y) in result])
[(0, [2, 8]), (1, [1, 1, 3, 5])]
"""
return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
def pipe(self, command, env={}):
"""
Return an RDD created by piping elements to a forked external process.
>>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
['1', '2', '3']
"""
def func(iterator):
pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
def pipe_objs(out):
for obj in iterator:
out.write(str(obj).rstrip('\n') + '\n')
out.close()
Thread(target=pipe_objs, args=[pipe.stdin]).start()
return (x.rstrip('\n') for x in pipe.stdout)
return self.mapPartitions(func)
def foreach(self, f):
"""
Applies a function to all elements of this RDD.
>>> def f(x): print x
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
self.map(f).collect() # Force evaluation
def collect(self):
"""
Return a list that contains all of the elements in this RDD.
"""
picklesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(picklesInJava))
def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
tempFile = NamedTemporaryFile(delete=False)
tempFile.close()
def clean_up_file():
try: os.unlink(tempFile.name)
except: pass
atexit.register(clean_up_file)
self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
for item in read_from_pickle_file(tempFile):
yield item
os.unlink(tempFile.name)
def reduce(self, f):
"""
Reduces the elements of this RDD using the specified associative binary
operator.
>>> from operator import add
>>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
15
>>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
10
"""
def func(iterator):
acc = None
for obj in iterator:
if acc is None:
acc = obj
else:
acc = f(obj, acc)
if acc is not None:
yield acc
vals = self.mapPartitions(func).collect()
return reduce(f, vals)
def fold(self, zeroValue, op):
"""
Aggregate the elements of each partition, and then the results for all
the partitions, using a given associative function and a neutral "zero
value."
The function C{op(t1, t2)} is allowed to modify C{t1} and return it
as its result value to avoid object allocation; however, it should not
modify C{t2}.
>>> from operator import add
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
15
"""
def func(iterator):
acc = zeroValue
for obj in iterator:
acc = op(obj, acc)
yield acc
vals = self.mapPartitions(func).collect()
return reduce(op, vals, zeroValue)
# TODO: aggregate
def sum(self):
"""
Add up the elements in this RDD.
>>> sc.parallelize([1.0, 2.0, 3.0]).sum()
6.0
"""
return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
def count(self):
"""
Return the number of elements in this RDD.
>>> sc.parallelize([2, 3, 4]).count()
3
"""
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
def countByValue(self):
"""
Return the count of each unique value in this RDD as a dictionary of
(value, count) pairs.
>>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
[(1, 2), (2, 3)]
"""
def countPartition(iterator):
counts = defaultdict(int)
for obj in iterator:
counts[obj] += 1
yield counts
def mergeMaps(m1, m2):
for (k, v) in m2.iteritems():
m1[k] += v
return m1
return self.mapPartitions(countPartition).reduce(mergeMaps)
def take(self, num):
"""
Take the first num elements of the RDD.
This currently scans the partitions *one by one*, so it will be slow if
a lot of partitions are required. In that case, use L{collect} to get
the whole RDD instead.
>>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
[2, 3]
>>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
[2, 3, 4, 5, 6]
"""
items = []
for partition in range(self._jrdd.splits().size()):
iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
items.extend(self._collect_iterator_through_file(iterator))
if len(items) >= num:
break
return items[:num]
def first(self):
"""
Return the first element in this RDD.
>>> sc.parallelize([2, 3, 4]).first()
2
"""
return self.take(1)[0]
def saveAsTextFile(self, path):
"""
Save this RDD as a text file, using string representations of elements.
>>> tempFile = NamedTemporaryFile(delete=True)
>>> tempFile.close()
>>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name)
>>> from fileinput import input
>>> from glob import glob
>>> ''.join(input(glob(tempFile.name + "/part-0000*")))
'0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
"""
def func(split, iterator):
return (str(x).encode("utf-8") for x in iterator)
keyed = PipelinedRDD(self, func)
keyed._bypass_serializer = True
keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
# Pair functions
def collectAsMap(self):
"""
Return the key-value pairs in this RDD to the master as a dictionary.
>>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
>>> m[1]
2
>>> m[3]
4
"""
return dict(self.collect())
def reduceByKey(self, func, numSplits=None):
"""
Merge the values for each key using an associative reduce function.
This will also perform the merging locally on each mapper before
sending results to a reducer, similarly to a "combiner" in MapReduce.
Output will be hash-partitioned with C{numSplits} splits, or the
default parallelism level if C{numSplits} is not specified.
>>> from operator import add
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(rdd.reduceByKey(add).collect())
[('a', 2), ('b', 1)]
"""
return self.combineByKey(lambda x: x, func, func, numSplits)
def reduceByKeyLocally(self, func):
"""
Merge the values for each key using an associative reduce function, but
return the results immediately to the master as a dictionary.
This will also perform the merging locally on each mapper before
sending results to a reducer, similarly to a "combiner" in MapReduce.
>>> from operator import add
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(rdd.reduceByKeyLocally(add).items())
[('a', 2), ('b', 1)]
"""
def reducePartition(iterator):
m = {}
for (k, v) in iterator:
m[k] = v if k not in m else func(m[k], v)
yield m
def mergeMaps(m1, m2):
for (k, v) in m2.iteritems():
m1[k] = v if k not in m1 else func(m1[k], v)
return m1
return self.mapPartitions(reducePartition).reduce(mergeMaps)
def countByKey(self):
"""
Count the number of elements for each key, and return the result to the
master as a dictionary.
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(rdd.countByKey().items())
[('a', 2), ('b', 1)]
"""
return self.map(lambda x: x[0]).countByValue()
def join(self, other, numSplits=None):
"""
Return an RDD containing all pairs of elements with matching keys in
C{self} and C{other}.
Each pair of elements will be returned as a (k, (v1, v2)) tuple, where
(k, v1) is in C{self} and (k, v2) is in C{other}.
Performs a hash join across the cluster.
>>> x = sc.parallelize([("a", 1), ("b", 4)])
>>> y = sc.parallelize([("a", 2), ("a", 3)])
>>> sorted(x.join(y).collect())
[('a', (1, 2)), ('a', (1, 3))]
"""
return python_join(self, other, numSplits)
def leftOuterJoin(self, other, numSplits=None):
"""
Perform a left outer join of C{self} and C{other}.
For each element (k, v) in C{self}, the resulting RDD will either
contain all pairs (k, (v, w)) for w in C{other}, or the pair
(k, (v, None)) if no elements in other have key k.
Hash-partitions the resulting RDD into the given number of partitions.
>>> x = sc.parallelize([("a", 1), ("b", 4)])
>>> y = sc.parallelize([("a", 2)])
>>> sorted(x.leftOuterJoin(y).collect())
[('a', (1, 2)), ('b', (4, None))]
"""
return python_left_outer_join(self, other, numSplits)
def rightOuterJoin(self, other, numSplits=None):
"""
Perform a right outer join of C{self} and C{other}.
For each element (k, w) in C{other}, the resulting RDD will either
contain all pairs (k, (v, w)) for v in this, or the pair (k, (None, w))
if no elements in C{self} have key k.
Hash-partitions the resulting RDD into the given number of partitions.
>>> x = sc.parallelize([("a", 1), ("b", 4)])
>>> y = sc.parallelize([("a", 2)])
>>> sorted(y.rightOuterJoin(x).collect())
[('a', (2, 1)), ('b', (None, 4))]
"""
return python_right_outer_join(self, other, numSplits)
# TODO: add option to control map-side combining
def partitionBy(self, numSplits, hashFunc=hash):
"""
Return a copy of the RDD partitioned using the specified partitioner.
>>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
>>> sets = pairs.partitionBy(2).glom().collect()
>>> set(sets[0]).intersection(set(sets[1]))
set([])
"""
if numSplits is None:
numSplits = self.ctx.defaultParallelism
# Transferring O(n) objects to Java is too expensive. Instead, we'll
# form the hash buckets in Python, transferring O(numSplits) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
buckets[hashFunc(k) % numSplits].append((k, v))
for (split, items) in buckets.iteritems():
yield str(split)
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
jrdd = pairRDD.partitionBy(partitioner)
jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
return RDD(jrdd, self.ctx)
# TODO: add control over map-side aggregation
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
numSplits=None):
"""
Generic function to combine the elements for each key using a custom
set of aggregation functions.
Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined
type" C. Note that V and C can be different -- for example, one might
group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]).
Users provide three functions:
- C{createCombiner}, which turns a V into a C (e.g., creates
a one-element list)
- C{mergeValue}, to merge a V into a C (e.g., adds it to the end of
a list)
- C{mergeCombiners}, to combine two C's into a single one.
In addition, users can control the partitioning of the output RDD.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> def f(x): return x
>>> def add(a, b): return a + str(b)
>>> sorted(x.combineByKey(str, add, add).collect())
[('a', '11'), ('b', '1')]
"""
if numSplits is None:
numSplits = self.ctx.defaultParallelism
def combineLocally(iterator):
combiners = {}
for (k, v) in iterator:
if k not in combiners:
combiners[k] = createCombiner(v)
else:
combiners[k] = mergeValue(combiners[k], v)
return combiners.iteritems()
locally_combined = self.mapPartitions(combineLocally)
shuffled = locally_combined.partitionBy(numSplits)
def _mergeCombiners(iterator):
combiners = {}
for (k, v) in iterator:
if not k in combiners:
combiners[k] = v
else:
combiners[k] = mergeCombiners(combiners[k], v)
return combiners.iteritems()
return shuffled.mapPartitions(_mergeCombiners)
# TODO: support variant with custom partitioner
def groupByKey(self, numSplits=None):
"""
Group the values for each key in the RDD into a single sequence.
Hash-partitions the resulting RDD with into numSplits partitions.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(x.groupByKey().collect())
[('a', [1, 1]), ('b', [1])]
"""
def createCombiner(x):
return [x]
def mergeValue(xs, x):
xs.append(x)
return xs
def mergeCombiners(a, b):
return a + b
return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
numSplits)
# TODO: add tests
def flatMapValues(self, f):
"""
Pass each value in the key-value pair RDD through a flatMap function
without changing the keys; this also retains the original RDD's
partitioning.
"""
flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
return self.flatMap(flat_map_fn, preservesPartitioning=True)
def mapValues(self, f):
"""
Pass each value in the key-value pair RDD through a map function
without changing the keys; this also retains the original RDD's
partitioning.
"""
map_values_fn = lambda (k, v): (k, f(v))
return self.map(map_values_fn, preservesPartitioning=True)
# TODO: support varargs cogroup of several RDDs.
def groupWith(self, other):
"""
Alias for cogroup.
"""
return self.cogroup(other)
# TODO: add variant with custom parittioner
def cogroup(self, other, numSplits=None):
"""
For each key k in C{self} or C{other}, return a resulting RDD that
contains a tuple with the list of values for that key in C{self} as well
as C{other}.
>>> x = sc.parallelize([("a", 1), ("b", 4)])
>>> y = sc.parallelize([("a", 2)])
>>> sorted(x.cogroup(y).collect())
[('a', ([1], [2])), ('b', ([4], []))]
"""
return python_cogroup(self, other, numSplits)
# TODO: `lookup` is disabled because we can't make direct comparisons based
# on the key; we need to compare the hash of the key to the hash of the
# keys in the pairs. This could be an expensive operation, since those
# hashes aren't retained.
class PipelinedRDD(RDD):
"""
Pipelined maps:
>>> rdd = sc.parallelize([1, 2, 3, 4])
>>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
[4, 8, 12, 16]
>>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
[4, 8, 12, 16]
Pipelined reduces:
>>> from operator import add
>>> rdd.map(lambda x: 2 * x).reduce(add)
20
>>> rdd.flatMap(lambda x: [x, x]).reduce(add)
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
if isinstance(prev, PipelinedRDD) and not prev.is_cached:
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
self.func = pipeline_func
self.preservesPartitioning = \
prev.preservesPartitioning and preservesPartitioning
self._prev_jrdd = prev._prev_jrdd
else:
self.func = func
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
self.is_cached = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
self._bypass_serializer = False
@property
def _jrdd(self):
if self._jrdd_val:
return self._jrdd_val
func = self.func
if not self._bypass_serializer and self.ctx.batchSize != 1:
oldfunc = self.func
batchSize = self.ctx.batchSize
def batched_func(split, iterator):
return batched(oldfunc(split, iterator), batchSize)
func = batched_func
cmds = [func, self._bypass_serializer]
pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx.gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
class_manifest = self._prev_jrdd.classManifest()
env = copy.copy(self.ctx.environment)
env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
def _test():
import doctest
from pyspark.context import SparkContext
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
doctest.testmod(globs=globs)
globs['sc'].stop()
if __name__ == "__main__":
_test()

Просмотреть файл

@ -0,0 +1,78 @@
import struct
import cPickle
class Batch(object):
"""
Used to store multiple RDD entries as a single Java object.
This relieves us from having to explicitly track whether an RDD
is stored as batches of objects and avoids problems when processing
the union() of batched and unbatched RDDs (e.g. the union() of textFile()
with another RDD).
"""
def __init__(self, items):
self.items = items
def batched(iterator, batchSize):
if batchSize == -1: # unlimited batch size
yield Batch(list(iterator))
else:
items = []
count = 0
for item in iterator:
items.append(item)
count += 1
if count == batchSize:
yield Batch(items)
items = []
count = 0
if items:
yield Batch(items)
def dump_pickle(obj):
return cPickle.dumps(obj, 2)
load_pickle = cPickle.loads
def read_long(stream):
length = stream.read(8)
if length == "":
raise EOFError
return struct.unpack("!q", length)[0]
def read_int(stream):
length = stream.read(4)
if length == "":
raise EOFError
return struct.unpack("!i", length)[0]
def write_with_length(obj, stream):
stream.write(struct.pack("!i", len(obj)))
stream.write(obj)
def read_with_length(stream):
length = read_int(stream)
obj = stream.read(length)
if obj == "":
raise EOFError
return obj
def read_from_pickle_file(stream):
try:
while True:
obj = load_pickle(read_with_length(stream))
if type(obj) == Batch: # We don't care about inheritance
for item in obj.items:
yield item
else:
yield obj
except EOFError:
return

17
python/pyspark/shell.py Normal file
Просмотреть файл

@ -0,0 +1,17 @@
"""
An interactive shell.
This fle is designed to be launched as a PYTHONSTARTUP script.
"""
import os
from pyspark.context import SparkContext
sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell")
print "Spark context avaiable as sc."
# The ./pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP,
# which allows us to execute the user's PYTHONSTARTUP file:
_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
if _pythonstartup and os.path.isfile(_pythonstartup):
execfile(_pythonstartup)

42
python/pyspark/worker.py Normal file
Просмотреть файл

@ -0,0 +1,42 @@
"""
Worker that receives input from Piped RDD.
"""
import sys
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import write_with_length, read_with_length, \
read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
# Redirect stdout to stderr so that users must return values from functions.
old_stdout = sys.stdout
sys.stdout = sys.stderr
def load_obj():
return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
def main():
split_index = read_int(sys.stdin)
num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
bid = read_long(sys.stdin)
value = read_with_length(sys.stdin)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
func = load_obj()
bypassSerializer = load_obj()
if bypassSerializer:
dumps = lambda x: x
else:
dumps = dump_pickle
iterator = read_from_pickle_file(sys.stdin)
for obj in func(split_index, iterator):
write_with_length(dumps(obj), old_stdout)
if __name__ == '__main__':
main()

26
python/run-tests Executable file
Просмотреть файл

@ -0,0 +1,26 @@
#!/usr/bin/env bash
# Figure out where the Scala framework is installed
FWDIR="$(cd `dirname $0`; cd ../; pwd)"
FAILED=0
$FWDIR/pyspark pyspark/rdd.py
FAILED=$(($?||$FAILED))
$FWDIR/pyspark -m doctest pyspark/broadcast.py
FAILED=$(($?||$FAILED))
if [[ $FAILED != 0 ]]; then
echo -en "\033[31m" # Red
echo "Had test failures; see logs."
echo -en "\033[0m" # No color
exit -1
else
echo -en "\033[32m" # Green
echo "Tests passed."
echo -en "\033[0m" # No color
fi
# TODO: in the long-run, it would be nice to use a test runner like `nose`.
# The doctest fixtures are the current barrier to doing this.

Просмотреть файл

@ -70,6 +70,11 @@
<profiles> <profiles>
<profile> <profile>
<id>hadoop1</id> <id>hadoop1</id>
<activation>
<property>
<name>!hadoopVersion</name>
</property>
</activation>
<properties> <properties>
<classifier>hadoop1</classifier> <classifier>hadoop1</classifier>
</properties> </properties>
@ -110,6 +115,12 @@
</profile> </profile>
<profile> <profile>
<id>hadoop2</id> <id>hadoop2</id>
<activation>
<property>
<name>hadoopVersion</name>
<value>2</value>
</property>
</activation>
<properties> <properties>
<classifier>hadoop2</classifier> <classifier>hadoop2</classifier>
</properties> </properties>

Просмотреть файл

@ -72,6 +72,11 @@
<profiles> <profiles>
<profile> <profile>
<id>hadoop1</id> <id>hadoop1</id>
<activation>
<property>
<name>!hadoopVersion</name>
</property>
</activation>
<properties> <properties>
<classifier>hadoop1</classifier> <classifier>hadoop1</classifier>
</properties> </properties>
@ -116,6 +121,12 @@
</profile> </profile>
<profile> <profile>
<id>hadoop2</id> <id>hadoop2</id>
<activation>
<property>
<name>hadoopVersion</name>
<value>2</value>
</property>
</activation>
<properties> <properties>
<classifier>hadoop2</classifier> <classifier>hadoop2</classifier>
</properties> </properties>

16
run
Просмотреть файл

@ -64,6 +64,7 @@ REPL_DIR="$FWDIR/repl"
EXAMPLES_DIR="$FWDIR/examples" EXAMPLES_DIR="$FWDIR/examples"
BAGEL_DIR="$FWDIR/bagel" BAGEL_DIR="$FWDIR/bagel"
STREAMING_DIR="$FWDIR/streaming" STREAMING_DIR="$FWDIR/streaming"
PYSPARK_DIR="$FWDIR/python"
# Build up classpath # Build up classpath
CLASSPATH="$SPARK_CLASSPATH" CLASSPATH="$SPARK_CLASSPATH"
@ -77,20 +78,17 @@ CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
if [ -e "$FWDIR/lib_managed" ]; then if [ -e "$FWDIR/lib_managed" ]; then
for jar in `find "$FWDIR/lib_managed/jars" -name '*jar'`; do CLASSPATH+=":$FWDIR/lib_managed/jars/*"
CLASSPATH+=":$jar" CLASSPATH+=":$FWDIR/lib_managed/bundles/*"
done
for jar in `find "$FWDIR/lib_managed/bundles" -name '*jar'`; do
CLASSPATH+=":$jar"
done
fi fi
for jar in `find "$REPL_DIR/lib" -name '*jar'`; do CLASSPATH+=":$REPL_DIR/lib/*"
CLASSPATH+=":$jar"
done
for jar in `find "$REPL_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do for jar in `find "$REPL_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
CLASSPATH+=":$jar" CLASSPATH+=":$jar"
done done
CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
CLASSPATH+=":$jar"
done
export CLASSPATH # Needed for spark-shell export CLASSPATH # Needed for spark-shell
# Figure out whether to run our class with java or with the scala launcher. # Figure out whether to run our class with java or with the scala launcher.

Просмотреть файл

@ -34,6 +34,7 @@ set CORE_DIR=%FWDIR%core
set REPL_DIR=%FWDIR%repl set REPL_DIR=%FWDIR%repl
set EXAMPLES_DIR=%FWDIR%examples set EXAMPLES_DIR=%FWDIR%examples
set BAGEL_DIR=%FWDIR%bagel set BAGEL_DIR=%FWDIR%bagel
set PYSPARK_DIR=%FWDIR%python
rem Build up classpath rem Build up classpath
set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
@ -42,6 +43,7 @@ set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMP
for /R "%FWDIR%\lib_managed\jars" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j for /R "%FWDIR%\lib_managed\jars" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
for /R "%FWDIR%\lib_managed\bundles" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j for /R "%FWDIR%\lib_managed\bundles" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
for /R "%REPL_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j for /R "%REPL_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
for /R "%PYSPARK_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
rem Figure out whether to run our class with java or with the scala launcher. rem Figure out whether to run our class with java or with the scala launcher.