Mirror of https://github.com/microsoft/spark.git

Use numpy in Python k-means example.

Parent: fd94e5443c
Commit: 607b53abfc

@@ -101,7 +101,13 @@ trait PythonRDDBase {
       stream.readFully(obj)
       obj
     } catch {
-      case eof: EOFException => { new Array[Byte](0) }
+      case eof: EOFException => {
+        val exitStatus = proc.waitFor()
+        if (exitStatus != 0) {
+          throw new Exception("Subprocess exited with status " + exitStatus)
+        }
+        new Array[Byte](0)
+      }
       case e => throw e
     }
   }

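The Scala change above stops swallowing worker failures: when the worker's output stream hits EOF, the parent now waits for the subprocess and raises if it exited with a non-zero status, instead of silently returning an empty byte array. A rough Python-only sketch of the same pattern (not the Spark code, just the idea of checking the exit status once the stream is exhausted):

import subprocess

# Hypothetical stand-in for the worker subprocess; any command works here.
proc = subprocess.Popen(["python", "-c", "print('ok')"], stdout=subprocess.PIPE)

data = proc.stdout.read()      # reading to the end of the stream signals EOF
exit_status = proc.wait()      # analogous to proc.waitFor() in the Scala diff
if exit_status != 0:
    raise Exception("Subprocess exited with status %d" % exit_status)
# Only a clean exit falls through to the "empty result" case.
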
@@ -1,25 +1,18 @@
 import sys
 
 from pyspark.context import SparkContext
+from numpy import array, sum as np_sum
 
 
 def parseVector(line):
-    return [float(x) for x in line.split(' ')]
-
-
-def addVec(x, y):
-    return [a + b for (a, b) in zip(x, y)]
-
-
-def squaredDist(x, y):
-    return sum((a - b) ** 2 for (a, b) in zip(x, y))
+    return array([float(x) for x in line.split(' ')])
 
 
 def closestPoint(p, centers):
     bestIndex = 0
     closest = float("+inf")
     for i in range(len(centers)):
-        tempDist = squaredDist(p, centers[i])
+        tempDist = np_sum((p - centers[i]) ** 2)
         if tempDist < closest:
             closest = tempDist
             bestIndex = i

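With numpy, parseVector returns an array instead of a Python list, so the hand-rolled addVec and squaredDist helpers go away and closestPoint can use vectorized arithmetic directly. A self-contained sketch of the two remaining helpers on made-up input (the sample strings and the trailing return are invented for illustration):

from numpy import array, sum as np_sum

def parseVector(line):
    return array([float(x) for x in line.split(' ')])

def closestPoint(p, centers):
    bestIndex = 0
    closest = float("+inf")
    for i in range(len(centers)):
        tempDist = np_sum((p - centers[i]) ** 2)   # vectorized squared distance
        if tempDist < closest:
            closest = tempDist
            bestIndex = i
    return bestIndex

# Invented sample data, just to exercise the helpers outside Spark.
points = [parseVector(s) for s in ["0.0 0.0", "9.0 9.0", "0.1 0.2"]]
centers = [parseVector("0.0 0.1"), parseVector("10.0 10.0")]
print([closestPoint(p, centers) for p in points])   # -> [0, 1, 0]
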
@@ -41,14 +34,14 @@ if __name__ == "__main__":
     tempDist = 1.0
 
     while tempDist > convergeDist:
-        closest = data.mapPairs(
+        closest = data.map(
             lambda p : (closestPoint(p, kPoints), (p, 1)))
         pointStats = closest.reduceByKey(
-            lambda (x1, y1), (x2, y2): (addVec(x1, x2), y1 + y2))
-        newPoints = pointStats.mapPairs(
-            lambda (x, (y, z)): (x, [a / z for a in y])).collect()
+            lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2))
+        newPoints = pointStats.map(
+            lambda (x, (y, z)): (x, y / z)).collect()
 
-        tempDist = sum(squaredDist(kPoints[x], y) for (x, y) in newPoints)
+        tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints)
 
         for (x, y) in newPoints:
             kPoints[x] = y

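In the driver loop, numpy's elementwise + and / replace addVec and the per-component list comprehension, and the convergence check uses the same vectorized squared distance. A purely local sketch of one update step under the same arithmetic (no Spark; the data, the initial centers, and the dict standing in for reduceByKey are all invented for illustration):

from numpy import array, sum as np_sum

data = [array(p) for p in [[0.0, 0.0], [0.2, 0.1], [9.0, 9.0], [9.1, 8.9]]]   # sample points
kPoints = [array([0.0, 0.0]), array([10.0, 10.0])]                            # initial centers

def closestPoint(p, centers):
    return min(range(len(centers)), key=lambda i: np_sum((p - centers[i]) ** 2))

stats = {}                                    # cluster index -> (vector sum, count)
for p in data:
    i = closestPoint(p, kPoints)
    s, n = stats.get(i, (0, 0))
    stats[i] = (s + p, n + 1)                 # numpy '+' replaces addVec

newPoints = [(i, s / n) for i, (s, n) in stats.items()]   # numpy '/' replaces the list comprehension
tempDist = sum(np_sum((kPoints[i] - y) ** 2) for (i, y) in newPoints)
for (i, y) in newPoints:
    kPoints[i] = y
print(kPoints, tempDist)
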
@@ -71,7 +71,7 @@ class RDD(object):
 
     def takeSample(self, withReplacement, num, seed):
         vals = self._jrdd.takeSample(withReplacement, num, seed)
-        return [PickleSerializer.loads(x) for x in vals]
+        return [PickleSerializer.loads(bytes(x)) for x in vals]
 
     def union(self, other):
         """

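The takeSample fix wraps each element in bytes() before unpickling, presumably because the sampled values come back from the JVM as a byte-sequence type that the pickle-based serializer does not accept directly. A tiny illustration of the normalization (the bytearray here is an invented stand-in for what the gateway returns):

import pickle

raw = bytearray(pickle.dumps({"x": 1}))   # stand-in for an element returned by the JVM
obj = pickle.loads(bytes(raw))            # bytes(...) normalizes it before unpickling
print(obj)                                # -> {'x': 1}
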
@@ -218,17 +218,16 @@ class RDD(object):
 
     # TODO: pipelining
     # TODO: optimizations
-    def shuffle(self, numSplits):
+    def shuffle(self, numSplits, hashFunc=hash):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
-        pipe_command = RDD._get_pipe_command('shuffle_map_step', [])
+        pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc])
         class_manifest = self._jrdd.classManifest()
         python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
             pipe_command, False, self.ctx.pythonExec, class_manifest)
         partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
         jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
-        # TODO: extract second value.
         return RDD(jrdd, self.ctx)
 
 
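shuffle() now takes the hash function as a parameter and ships it to the Python worker via the pipe command, so keys can be partitioned with something other than the built-in hash. A usage sketch against the signature shown in this diff; the SparkContext sc and the parallelize call used to build a pair RDD are assumptions for illustration, not part of the patch:

def first_letter(key):
    # Custom partitioning hash: group keys by their first character.
    return ord(key[0])

pairs = sc.parallelize([("apple", 1), ("avocado", 2), ("banana", 3)])   # assumed pair-RDD construction

default_shuffle = pairs.shuffle(4)                      # partitions with the built-in hash()
by_letter = pairs.shuffle(4, hashFunc=first_letter)     # partitions with the custom hash
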
@@ -277,8 +276,6 @@ class RDD(object):
         map_values_fn = lambda (k, v): (k, f(v))
         return self.map(map_values_fn, preservesPartitioning=True)
 
-    # TODO: implement shuffle.
-
     # TODO: support varargs cogroup of several RDDs.
     def groupWith(self, other):
         return self.cogroup(other)

@@ -48,9 +48,6 @@ def do_map(flat=False):
     f = load_function()
     for obj in read_input():
         try:
-            #from pickletools import dis
-            #print repr(obj)
-            #print dis(obj)
             out = f(PickleSerializer.loads(obj))
             if out is not None:
                 if flat:

@@ -64,9 +61,10 @@ def do_map(flat=False):
 
 
 def do_shuffle_map_step():
+    hashFunc = load_function()
     for obj in read_input():
-        key = PickleSerializer.loads(obj)[1]
-        output(str(hash(key)))
+        key = PickleSerializer.loads(obj)[0]
+        output(str(hashFunc(key)))
         output(obj)
 
 
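On the worker side, do_shuffle_map_step now loads the hash function sent from the driver and hashes element [0] of each unpickled record (the key) rather than element [1]; for every record it emits the hash on one line and the untouched pickled pair on the next, so the JVM side can partition without unpickling anything. A local sketch of that framing, with read_input()/output() replaced by a list (everything outside the diffed lines is assumed):

import pickle

def do_shuffle_map_step(records, hashFunc=hash):
    out = []
    for obj in records:
        key = pickle.loads(obj)[0]        # element 0 of the (key, value) pair
        out.append(str(hashFunc(key)))    # hash line, consumed by the JVM partitioner
        out.append(obj)                   # original pickled record, passed through untouched
    return out

records = [pickle.dumps(("apple", 1)), pickle.dumps(("banana", 2))]
print(do_shuffle_map_step(records, hashFunc=lambda k: ord(k[0])))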