зеркало из https://github.com/dotnet/spark.git
Add microsoft-spark 3.2 jar (#1010)
This commit is contained in:
Родитель
6f82f15c22
Коммит
7bc016f5ed
|
@ -0,0 +1,77 @@
|
|||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>com.microsoft.scala</groupId>
|
||||
<artifactId>microsoft-spark</artifactId>
|
||||
<version>${microsoft-spark.version}</version>
|
||||
</parent>
|
||||
<artifactId>microsoft-spark-3-2_2.12</artifactId>
|
||||
<inceptionYear>2019</inceptionYear>
|
||||
<properties>
|
||||
<encoding>UTF-8</encoding>
|
||||
<scala.version>2.12.10</scala.version>
|
||||
<scala.binary.version>2.12</scala.binary.version>
|
||||
<spark.version>3.2.0</spark.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-core_${scala.binary.version}</artifactId>
|
||||
<version>${spark.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-sql_${scala.binary.version}</artifactId>
|
||||
<version>${spark.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.13.1</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.specs</groupId>
|
||||
<artifactId>specs</artifactId>
|
||||
<version>1.2.5</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<sourceDirectory>src/main/scala</sourceDirectory>
|
||||
<testSourceDirectory>src/test/scala</testSourceDirectory>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.scala-tools</groupId>
|
||||
<artifactId>maven-scala-plugin</artifactId>
|
||||
<version>2.15.2</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<goals>
|
||||
<goal>compile</goal>
|
||||
<goal>testCompile</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
<configuration>
|
||||
<scalaVersion>${scala.version}</scalaVersion>
|
||||
<args>
|
||||
<arg>-target:jvm-1.8</arg>
|
||||
<arg>-deprecation</arg>
|
||||
<arg>-feature</arg>
|
||||
</args>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import java.io.DataOutputStream
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
|
||||
import scala.collection.mutable.Queue
|
||||
|
||||
/**
|
||||
* CallbackClient is used to communicate with the Dotnet CallbackServer.
|
||||
* The client manages and maintains a pool of open CallbackConnections.
|
||||
* Any callback request is delegated to a new CallbackConnection or
|
||||
* unused CallbackConnection.
|
||||
* @param address The address of the Dotnet CallbackServer
|
||||
* @param port The port of the Dotnet CallbackServer
|
||||
*/
|
||||
class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging {
|
||||
private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]()
|
||||
|
||||
private[this] var isShutdown: Boolean = false
|
||||
|
||||
final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit =
|
||||
getOrCreateConnection() match {
|
||||
case Some(connection) =>
|
||||
try {
|
||||
connection.send(callbackId, writeBody)
|
||||
addConnection(connection)
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError(s"Error calling callback [callback id = $callbackId].", e)
|
||||
connection.close()
|
||||
throw e
|
||||
}
|
||||
case None => throw new Exception("Unable to get or create connection.")
|
||||
}
|
||||
|
||||
private def getOrCreateConnection(): Option[CallbackConnection] = synchronized {
|
||||
if (isShutdown) {
|
||||
logInfo("Cannot get or create connection while client is shutdown.")
|
||||
return None
|
||||
}
|
||||
|
||||
if (connectionPool.nonEmpty) {
|
||||
return Some(connectionPool.dequeue())
|
||||
}
|
||||
|
||||
Some(new CallbackConnection(serDe, address, port))
|
||||
}
|
||||
|
||||
private def addConnection(connection: CallbackConnection): Unit = synchronized {
|
||||
assert(connection != null)
|
||||
connectionPool.enqueue(connection)
|
||||
}
|
||||
|
||||
def shutdown(): Unit = synchronized {
|
||||
if (isShutdown) {
|
||||
logInfo("Shutdown called, but already shutdown.")
|
||||
return
|
||||
}
|
||||
|
||||
logInfo("Shutting down.")
|
||||
connectionPool.foreach(_.close)
|
||||
connectionPool.clear
|
||||
isShutdown = true
|
||||
}
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream}
|
||||
import java.net.Socket
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
|
||||
/**
|
||||
* CallbackConnection is used to process the callback communication
|
||||
* between the JVM and Dotnet. It uses a TCP socket to communicate with
|
||||
* the Dotnet CallbackServer and the socket is expected to be reused.
|
||||
* @param address The address of the Dotnet CallbackServer
|
||||
* @param port The port of the Dotnet CallbackServer
|
||||
*/
|
||||
class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging {
|
||||
private[this] val socket: Socket = new Socket(address, port)
|
||||
private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream)
|
||||
private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream)
|
||||
|
||||
def send(
|
||||
callbackId: Int,
|
||||
writeBody: (DataOutputStream, SerDe) => Unit): Unit = {
|
||||
logInfo(s"Calling callback [callback id = $callbackId] ...")
|
||||
|
||||
try {
|
||||
serDe.writeInt(outputStream, CallbackFlags.CALLBACK)
|
||||
serDe.writeInt(outputStream, callbackId)
|
||||
|
||||
val byteArrayOutputStream = new ByteArrayOutputStream()
|
||||
writeBody(new DataOutputStream(byteArrayOutputStream), serDe)
|
||||
serDe.writeInt(outputStream, byteArrayOutputStream.size)
|
||||
byteArrayOutputStream.writeTo(outputStream);
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
throw new Exception("Error writing to stream.", e)
|
||||
}
|
||||
}
|
||||
|
||||
logInfo(s"Signaling END_OF_STREAM.")
|
||||
try {
|
||||
serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
|
||||
outputStream.flush()
|
||||
|
||||
val endOfStreamResponse = readFlag(inputStream)
|
||||
endOfStreamResponse match {
|
||||
case CallbackFlags.END_OF_STREAM =>
|
||||
logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.")
|
||||
case _ => {
|
||||
throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " +
|
||||
s"Received: $endOfStreamResponse")
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
throw new Exception("Error while verifying end of stream.", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def close(): Unit = {
|
||||
try {
|
||||
serDe.writeInt(outputStream, CallbackFlags.CLOSE)
|
||||
outputStream.flush()
|
||||
} catch {
|
||||
case e: Exception => logInfo("Unable to send close to .NET callback server.", e)
|
||||
}
|
||||
|
||||
close(socket)
|
||||
close(outputStream)
|
||||
close(inputStream)
|
||||
}
|
||||
|
||||
private def close(s: Socket): Unit = {
|
||||
try {
|
||||
assert(s != null)
|
||||
s.close()
|
||||
} catch {
|
||||
case e: Exception => logInfo("Unable to close socket.", e)
|
||||
}
|
||||
}
|
||||
|
||||
private def close(c: Closeable): Unit = {
|
||||
try {
|
||||
assert(c != null)
|
||||
c.close()
|
||||
} catch {
|
||||
case e: Exception => logInfo("Unable to close closeable.", e)
|
||||
}
|
||||
}
|
||||
|
||||
private def readFlag(inputStream: DataInputStream): Int = {
|
||||
val callbackFlag = serDe.readInt(inputStream)
|
||||
if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) {
|
||||
val exceptionMessage = serDe.readString(inputStream)
|
||||
throw new DotnetException(exceptionMessage)
|
||||
}
|
||||
callbackFlag
|
||||
}
|
||||
|
||||
private object CallbackFlags {
|
||||
val CLOSE: Int = -1
|
||||
val CALLBACK: Int = -2
|
||||
val DOTNET_EXCEPTION_THROWN: Int = -3
|
||||
val END_OF_STREAM: Int = -4
|
||||
}
|
||||
}
|
|
@ -0,0 +1,113 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import java.net.InetSocketAddress
|
||||
import java.util.concurrent.TimeUnit
|
||||
import io.netty.bootstrap.ServerBootstrap
|
||||
import io.netty.channel.nio.NioEventLoopGroup
|
||||
import io.netty.channel.socket.SocketChannel
|
||||
import io.netty.channel.socket.nio.NioServerSocketChannel
|
||||
import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
|
||||
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
|
||||
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS
|
||||
import org.apache.spark.{SparkConf, SparkEnv}
|
||||
|
||||
/**
|
||||
* Netty server that invokes JVM calls based upon receiving messages from .NET.
|
||||
* The implementation mirrors the RBackend.
|
||||
*
|
||||
*/
|
||||
class DotnetBackend extends Logging {
|
||||
self => // for accessing the this reference in inner class(ChannelInitializer)
|
||||
private[this] var channelFuture: ChannelFuture = _
|
||||
private[this] var bootstrap: ServerBootstrap = _
|
||||
private[this] var bossGroup: EventLoopGroup = _
|
||||
private[this] val objectTracker = new JVMObjectTracker
|
||||
|
||||
@volatile
|
||||
private[dotnet] var callbackClient: Option[CallbackClient] = None
|
||||
|
||||
def init(portNumber: Int): Int = {
|
||||
val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
|
||||
val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS)
|
||||
logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.")
|
||||
bossGroup = new NioEventLoopGroup(numBackendThreads)
|
||||
val workerGroup = bossGroup
|
||||
|
||||
bootstrap = new ServerBootstrap()
|
||||
.group(bossGroup, workerGroup)
|
||||
.channel(classOf[NioServerSocketChannel])
|
||||
|
||||
bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
|
||||
def initChannel(ch: SocketChannel): Unit = {
|
||||
ch.pipeline()
|
||||
.addLast("encoder", new ByteArrayEncoder())
|
||||
.addLast(
|
||||
"frameDecoder",
|
||||
// maxFrameLength = 2G
|
||||
// lengthFieldOffset = 0
|
||||
// lengthFieldLength = 4
|
||||
// lengthAdjustment = 0
|
||||
// initialBytesToStrip = 4, i.e. strip out the length field itself
|
||||
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
|
||||
.addLast("decoder", new ByteArrayDecoder())
|
||||
.addLast("handler", new DotnetBackendHandler(self, objectTracker))
|
||||
}
|
||||
})
|
||||
|
||||
channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber))
|
||||
channelFuture.syncUninterruptibly()
|
||||
channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
|
||||
}
|
||||
|
||||
private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized {
|
||||
callbackClient = callbackClient match {
|
||||
case Some(_) => throw new Exception("Callback client already set.")
|
||||
case None =>
|
||||
logInfo(s"Connecting to a callback server at $address:$port")
|
||||
Some(new CallbackClient(new SerDe(objectTracker), address, port))
|
||||
}
|
||||
}
|
||||
|
||||
private[dotnet] def shutdownCallbackClient(): Unit = synchronized {
|
||||
callbackClient match {
|
||||
case Some(client) => client.shutdown()
|
||||
case None => logInfo("Callback server has already been shutdown.")
|
||||
}
|
||||
callbackClient = None
|
||||
}
|
||||
|
||||
def run(): Unit = {
|
||||
channelFuture.channel.closeFuture().syncUninterruptibly()
|
||||
}
|
||||
|
||||
def close(): Unit = {
|
||||
if (channelFuture != null) {
|
||||
// close is a local operation and should finish within milliseconds; timeout just to be safe
|
||||
channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS)
|
||||
channelFuture = null
|
||||
}
|
||||
if (bootstrap != null && bootstrap.config().group() != null) {
|
||||
bootstrap.config().group().shutdownGracefully()
|
||||
}
|
||||
if (bootstrap != null && bootstrap.config().childGroup() != null) {
|
||||
bootstrap.config().childGroup().shutdownGracefully()
|
||||
}
|
||||
bootstrap = null
|
||||
|
||||
objectTracker.clear()
|
||||
|
||||
// Send close to .NET callback server.
|
||||
shutdownCallbackClient()
|
||||
|
||||
// Shutdown the thread pool whose executors could still be running.
|
||||
ThreadPool.shutdown()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,335 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.language.existentials
|
||||
|
||||
/**
|
||||
* Handler for DotnetBackend.
|
||||
* This implementation is similar to RBackendHandler.
|
||||
*/
|
||||
class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker)
|
||||
extends SimpleChannelInboundHandler[Array[Byte]]
|
||||
with Logging {
|
||||
|
||||
private[this] val serDe = new SerDe(objectsTracker)
|
||||
|
||||
override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
|
||||
val reply = handleBackendRequest(msg)
|
||||
ctx.write(reply)
|
||||
}
|
||||
|
||||
override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
|
||||
ctx.flush()
|
||||
}
|
||||
|
||||
def handleBackendRequest(msg: Array[Byte]): Array[Byte] = {
|
||||
val bis = new ByteArrayInputStream(msg)
|
||||
val dis = new DataInputStream(bis)
|
||||
|
||||
val bos = new ByteArrayOutputStream()
|
||||
val dos = new DataOutputStream(bos)
|
||||
|
||||
// First bit is isStatic
|
||||
val isStatic = serDe.readBoolean(dis)
|
||||
val processId = serDe.readInt(dis)
|
||||
val threadId = serDe.readInt(dis)
|
||||
val objId = serDe.readString(dis)
|
||||
val methodName = serDe.readString(dis)
|
||||
val numArgs = serDe.readInt(dis)
|
||||
|
||||
if (objId == "DotnetHandler") {
|
||||
methodName match {
|
||||
case "stopBackend" =>
|
||||
serDe.writeInt(dos, 0)
|
||||
serDe.writeType(dos, "void")
|
||||
server.close()
|
||||
case "rm" =>
|
||||
try {
|
||||
val t = serDe.readObjectType(dis)
|
||||
assert(t == 'c')
|
||||
val objToRemove = serDe.readString(dis)
|
||||
objectsTracker.remove(objToRemove)
|
||||
serDe.writeInt(dos, 0)
|
||||
serDe.writeObject(dos, null)
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError(s"Removing $objId failed", e)
|
||||
serDe.writeInt(dos, -1)
|
||||
}
|
||||
case "rmThread" =>
|
||||
try {
|
||||
assert(serDe.readObjectType(dis) == 'i')
|
||||
val processId = serDe.readInt(dis)
|
||||
assert(serDe.readObjectType(dis) == 'i')
|
||||
val threadToDelete = serDe.readInt(dis)
|
||||
val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
|
||||
serDe.writeInt(dos, 0)
|
||||
serDe.writeObject(dos, result.asInstanceOf[AnyRef])
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError(s"Removing thread $threadId failed", e)
|
||||
serDe.writeInt(dos, -1)
|
||||
}
|
||||
case "connectCallback" =>
|
||||
assert(serDe.readObjectType(dis) == 'c')
|
||||
val address = serDe.readString(dis)
|
||||
assert(serDe.readObjectType(dis) == 'i')
|
||||
val port = serDe.readInt(dis)
|
||||
server.setCallbackClient(address, port)
|
||||
serDe.writeInt(dos, 0)
|
||||
|
||||
// Sends reference of CallbackClient to dotnet side,
|
||||
// so that dotnet process can send the client back to Java side
|
||||
// when calling any API containing callback functions.
|
||||
serDe.writeObject(dos, server.callbackClient)
|
||||
case "closeCallback" =>
|
||||
logInfo("Requesting to close callback client")
|
||||
server.shutdownCallbackClient()
|
||||
serDe.writeInt(dos, 0)
|
||||
serDe.writeType(dos, "void")
|
||||
case _ => dos.writeInt(-1)
|
||||
}
|
||||
} else {
|
||||
ThreadPool
|
||||
.run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
|
||||
}
|
||||
|
||||
bos.toByteArray
|
||||
}
|
||||
|
||||
override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
|
||||
// Skip logging the exception message if the connection was disconnected from
|
||||
// the .NET side so that .NET side doesn't have to explicitly close the connection via
|
||||
// "stopBackend." Note that an exception is still thrown if the exit status is non-zero,
|
||||
// so skipping this kind of exception message does not affect the debugging.
|
||||
if (!cause.getMessage.contains(
|
||||
"An existing connection was forcibly closed by the remote host")) {
|
||||
logError("Exception caught: ", cause)
|
||||
}
|
||||
|
||||
// Close the connection when an exception is raised.
|
||||
ctx.close()
|
||||
}
|
||||
|
||||
def handleMethodCall(
|
||||
isStatic: Boolean,
|
||||
objId: String,
|
||||
methodName: String,
|
||||
numArgs: Int,
|
||||
dis: DataInputStream,
|
||||
dos: DataOutputStream): Unit = {
|
||||
var obj: Object = null
|
||||
var args: Array[java.lang.Object] = null
|
||||
var methods: Array[java.lang.reflect.Method] = null
|
||||
|
||||
try {
|
||||
val cls = if (isStatic) {
|
||||
Utils.classForName(objId)
|
||||
} else {
|
||||
objectsTracker.get(objId) match {
|
||||
case None => throw new IllegalArgumentException("Object not found " + objId)
|
||||
case Some(o) =>
|
||||
obj = o
|
||||
o.getClass
|
||||
}
|
||||
}
|
||||
|
||||
args = readArgs(numArgs, dis)
|
||||
methods = cls.getMethods
|
||||
|
||||
val selectedMethods = methods.filter(m => m.getName == methodName)
|
||||
if (selectedMethods.length > 0) {
|
||||
val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args)
|
||||
|
||||
if (index.isEmpty) {
|
||||
logWarning(
|
||||
s"cannot find matching method ${cls}.$methodName. "
|
||||
+ s"Candidates are:")
|
||||
selectedMethods.foreach { method =>
|
||||
logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})")
|
||||
}
|
||||
throw new Exception(s"No matched method found for $cls.$methodName")
|
||||
}
|
||||
|
||||
val ret = selectedMethods(index.get).invoke(obj, args: _*)
|
||||
|
||||
// Write status bit
|
||||
serDe.writeInt(dos, 0)
|
||||
serDe.writeObject(dos, ret.asInstanceOf[AnyRef])
|
||||
} else if (methodName == "<init>") {
|
||||
// methodName should be "<init>" for constructor
|
||||
val ctor = cls.getConstructors.filter { x =>
|
||||
matchMethod(numArgs, args, x.getParameterTypes)
|
||||
}.head
|
||||
|
||||
val obj = ctor.newInstance(args: _*)
|
||||
|
||||
serDe.writeInt(dos, 0)
|
||||
serDe.writeObject(dos, obj.asInstanceOf[AnyRef])
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"invalid method " + methodName + " for object " + objId)
|
||||
}
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
val jvmObj = objectsTracker.get(objId)
|
||||
val jvmObjName = jvmObj match {
|
||||
case Some(jObj) => jObj.getClass.getName
|
||||
case None => "NullObject"
|
||||
}
|
||||
val argsStr = args
|
||||
.map(arg => {
|
||||
if (arg != null) {
|
||||
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
|
||||
} else {
|
||||
"[Value: NULL]"
|
||||
}
|
||||
})
|
||||
.mkString(", ")
|
||||
|
||||
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
|
||||
|
||||
if (methods != null) {
|
||||
logDebug(s"All methods for $jvmObjName:")
|
||||
methods.foreach(m => logDebug(m.toString))
|
||||
}
|
||||
|
||||
serDe.writeInt(dos, -1)
|
||||
serDe.writeString(dos, Utils.exceptionString(e.getCause))
|
||||
}
|
||||
}
|
||||
|
||||
// Read a number of arguments from the data input stream
|
||||
def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
|
||||
(0 until numArgs).map { arg =>
|
||||
serDe.readObject(dis)
|
||||
}.toArray
|
||||
}
|
||||
|
||||
// Checks if the arguments passed in args matches the parameter types.
|
||||
// NOTE: Currently we do exact match. We may add type conversions later.
|
||||
def matchMethod(
|
||||
numArgs: Int,
|
||||
args: Array[java.lang.Object],
|
||||
parameterTypes: Array[Class[_]]): Boolean = {
|
||||
if (parameterTypes.length != numArgs) {
|
||||
return false
|
||||
}
|
||||
|
||||
for (i <- 0 until numArgs) {
|
||||
val parameterType = parameterTypes(i)
|
||||
var parameterWrapperType = parameterType
|
||||
|
||||
// Convert native parameters to Object types as args is Array[Object] here
|
||||
if (parameterType.isPrimitive) {
|
||||
parameterWrapperType = parameterType match {
|
||||
case java.lang.Integer.TYPE => classOf[java.lang.Integer]
|
||||
case java.lang.Long.TYPE => classOf[java.lang.Long]
|
||||
case java.lang.Double.TYPE => classOf[java.lang.Double]
|
||||
case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
|
||||
case _ => parameterType
|
||||
}
|
||||
}
|
||||
|
||||
if (!parameterWrapperType.isInstance(args(i))) {
|
||||
// non primitive types
|
||||
if (!parameterType.isPrimitive && args(i) != null) {
|
||||
return false
|
||||
}
|
||||
|
||||
// primitive types
|
||||
if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
// Find a matching method signature in an array of signatures of constructors
|
||||
// or methods of the same name according to the passed arguments. Arguments
|
||||
// may be converted in order to match a signature.
|
||||
//
|
||||
// Note that in Java reflection, constructors and normal methods are of different
|
||||
// classes, and share no parent class that provides methods for reflection uses.
|
||||
// There is no unified way to handle them in this function. So an array of signatures
|
||||
// is passed in instead of an array of candidate constructors or methods.
|
||||
//
|
||||
// Returns an Option[Int] which is the index of the matched signature in the array.
|
||||
def findMatchedSignature(
|
||||
parameterTypesOfMethods: Array[Array[Class[_]]],
|
||||
args: Array[Object]): Option[Int] = {
|
||||
val numArgs = args.length
|
||||
|
||||
for (index <- parameterTypesOfMethods.indices) {
|
||||
val parameterTypes = parameterTypesOfMethods(index)
|
||||
|
||||
if (parameterTypes.length == numArgs) {
|
||||
var argMatched = true
|
||||
var i = 0
|
||||
while (i < numArgs && argMatched) {
|
||||
val parameterType = parameterTypes(i)
|
||||
|
||||
if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) {
|
||||
// The case that the parameter type is a Scala Seq and the argument
|
||||
// is a Java array is considered matching. The array will be converted
|
||||
// to a Seq later if this method is matched.
|
||||
} else {
|
||||
var parameterWrapperType = parameterType
|
||||
|
||||
// Convert native parameters to Object types as args is Array[Object] here
|
||||
if (parameterType.isPrimitive) {
|
||||
parameterWrapperType = parameterType match {
|
||||
case java.lang.Integer.TYPE => classOf[java.lang.Integer]
|
||||
case java.lang.Long.TYPE => classOf[java.lang.Long]
|
||||
case java.lang.Double.TYPE => classOf[java.lang.Double]
|
||||
case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
|
||||
case _ => parameterType
|
||||
}
|
||||
}
|
||||
if ((parameterType.isPrimitive || args(i) != null) &&
|
||||
!parameterWrapperType.isInstance(args(i))) {
|
||||
argMatched = false
|
||||
}
|
||||
}
|
||||
|
||||
i = i + 1
|
||||
}
|
||||
|
||||
if (argMatched) {
|
||||
// For now, we return the first matching method.
|
||||
// TODO: find best method in matching methods.
|
||||
|
||||
// Convert args if needed
|
||||
val parameterTypes = parameterTypesOfMethods(index)
|
||||
|
||||
for (i <- 0 until numArgs) {
|
||||
if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) {
|
||||
// Convert a Java array to scala Seq
|
||||
args(i) = args(i).asInstanceOf[Array[_]].toSeq
|
||||
}
|
||||
}
|
||||
|
||||
return Some(index)
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
def logError(id: String, e: Exception): Unit = {}
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
class DotnetException(message: String, cause: Throwable)
|
||||
extends Exception(message, cause) {
|
||||
|
||||
def this(message: String) = this(message, null)
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.api.python._
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
object DotnetRDD {
|
||||
def createPythonRDD(
|
||||
parent: RDD[_],
|
||||
func: PythonFunction,
|
||||
preservePartitoning: Boolean): PythonRDD = {
|
||||
new PythonRDD(parent, func, preservePartitoning)
|
||||
}
|
||||
|
||||
def createJavaRDDFromArray(
|
||||
sc: SparkContext,
|
||||
arr: Array[Array[Byte]],
|
||||
numSlices: Int): JavaRDD[Array[Byte]] = {
|
||||
JavaRDD.fromRDD(sc.parallelize(arr, numSlices))
|
||||
}
|
||||
|
||||
def toJavaRDD(rdd: RDD[_]): JavaRDD[_] = JavaRDD.fromRDD(rdd)
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
|
||||
/**
|
||||
* Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects.
|
||||
*/
|
||||
private[dotnet] class JVMObjectTracker {
|
||||
|
||||
// Multiple threads may access objMap and increase objCounter. Because get method return Option,
|
||||
// it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
|
||||
private[this] val objMap = new HashMap[String, Object]
|
||||
private[this] var objCounter: Int = 1
|
||||
|
||||
def getObject(id: String): Object = {
|
||||
synchronized {
|
||||
objMap(id)
|
||||
}
|
||||
}
|
||||
|
||||
def get(id: String): Option[Object] = {
|
||||
synchronized {
|
||||
objMap.get(id)
|
||||
}
|
||||
}
|
||||
|
||||
def put(obj: Object): String = {
|
||||
synchronized {
|
||||
val objId = objCounter.toString
|
||||
objCounter = objCounter + 1
|
||||
objMap.put(objId, obj)
|
||||
objId
|
||||
}
|
||||
}
|
||||
|
||||
def remove(id: String): Option[Object] = {
|
||||
synchronized {
|
||||
objMap.remove(id)
|
||||
}
|
||||
}
|
||||
|
||||
def clear(): Unit = {
|
||||
synchronized {
|
||||
objMap.clear()
|
||||
objCounter = 1
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.api.dotnet
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
|
||||
/*
|
||||
* Utils for JvmBridge.
|
||||
*/
|
||||
object JvmBridgeUtils {
|
||||
def getKeyValuePairAsString(kvp: (String, String)): String = {
|
||||
return kvp._1 + "=" + kvp._2
|
||||
}
|
||||
|
||||
def getKeyValuePairArrayAsString(kvpArray: Array[(String, String)]): String = {
|
||||
val sb = new StringBuilder
|
||||
|
||||
for (kvp <- kvpArray) {
|
||||
sb.append(getKeyValuePairAsString(kvp))
|
||||
sb.append(";")
|
||||
}
|
||||
|
||||
sb.toString
|
||||
}
|
||||
|
||||
def getSparkConfAsString(sparkConf: SparkConf): String = {
|
||||
getKeyValuePairArrayAsString(sparkConf.getAll)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,387 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import java.io.{DataInputStream, DataOutputStream}
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.sql.{Date, Time, Timestamp}
|
||||
|
||||
import org.apache.spark.sql.Row
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
/**
|
||||
* Class responsible for serialization and deserialization between CLR & JVM.
|
||||
* This implementation of methods is mostly identical to the SerDe implementation in R.
|
||||
*/
|
||||
class SerDe(val tracker: JVMObjectTracker) {
|
||||
|
||||
def readObjectType(dis: DataInputStream): Char = {
|
||||
dis.readByte().toChar
|
||||
}
|
||||
|
||||
def readObject(dis: DataInputStream): Object = {
|
||||
val dataType = readObjectType(dis)
|
||||
readTypedObject(dis, dataType)
|
||||
}
|
||||
|
||||
private def readTypedObject(dis: DataInputStream, dataType: Char): Object = {
|
||||
dataType match {
|
||||
case 'n' => null
|
||||
case 'i' => new java.lang.Integer(readInt(dis))
|
||||
case 'g' => new java.lang.Long(readLong(dis))
|
||||
case 'd' => new java.lang.Double(readDouble(dis))
|
||||
case 'b' => new java.lang.Boolean(readBoolean(dis))
|
||||
case 'c' => readString(dis)
|
||||
case 'e' => readMap(dis)
|
||||
case 'r' => readBytes(dis)
|
||||
case 'l' => readList(dis)
|
||||
case 'D' => readDate(dis)
|
||||
case 't' => readTime(dis)
|
||||
case 'j' => tracker.getObject(readString(dis))
|
||||
case 'R' => readRowArr(dis)
|
||||
case 'O' => readObjectArr(dis)
|
||||
case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
|
||||
}
|
||||
}
|
||||
|
||||
private def readBytes(in: DataInputStream): Array[Byte] = {
|
||||
val len = readInt(in)
|
||||
val out = new Array[Byte](len)
|
||||
in.readFully(out)
|
||||
out
|
||||
}
|
||||
|
||||
def readInt(in: DataInputStream): Int = {
|
||||
in.readInt()
|
||||
}
|
||||
|
||||
private def readLong(in: DataInputStream): Long = {
|
||||
in.readLong()
|
||||
}
|
||||
|
||||
private def readDouble(in: DataInputStream): Double = {
|
||||
in.readDouble()
|
||||
}
|
||||
|
||||
private def readStringBytes(in: DataInputStream, len: Int): String = {
|
||||
val bytes = new Array[Byte](len)
|
||||
in.readFully(bytes)
|
||||
val str = new String(bytes, "UTF-8")
|
||||
str
|
||||
}
|
||||
|
||||
def readString(in: DataInputStream): String = {
|
||||
val len = in.readInt()
|
||||
readStringBytes(in, len)
|
||||
}
|
||||
|
||||
def readBoolean(in: DataInputStream): Boolean = {
|
||||
in.readBoolean()
|
||||
}
|
||||
|
||||
private def readDate(in: DataInputStream): Date = {
|
||||
Date.valueOf(readString(in))
|
||||
}
|
||||
|
||||
private def readTime(in: DataInputStream): Timestamp = {
|
||||
val seconds = in.readDouble()
|
||||
val sec = Math.floor(seconds).toLong
|
||||
val t = new Timestamp(sec * 1000L)
|
||||
t.setNanos(((seconds - sec) * 1e9).toInt)
|
||||
t
|
||||
}
|
||||
|
||||
private def readRow(in: DataInputStream): Row = {
|
||||
val len = readInt(in)
|
||||
Row.fromSeq((0 until len).map(_ => readObject(in)))
|
||||
}
|
||||
|
||||
private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = {
|
||||
val len = readInt(in)
|
||||
(0 until len).map(_ => readBytes(in)).toArray
|
||||
}
|
||||
|
||||
private def readIntArr(in: DataInputStream): Array[Int] = {
|
||||
val len = readInt(in)
|
||||
(0 until len).map(_ => readInt(in)).toArray
|
||||
}
|
||||
|
||||
private def readLongArr(in: DataInputStream): Array[Long] = {
|
||||
val len = readInt(in)
|
||||
(0 until len).map(_ => readLong(in)).toArray
|
||||
}
|
||||
|
||||
private def readDoubleArr(in: DataInputStream): Array[Double] = {
|
||||
val len = readInt(in)
|
||||
(0 until len).map(_ => readDouble(in)).toArray
|
||||
}
|
||||
|
||||
private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = {
|
||||
val len = readInt(in)
|
||||
(0 until len).map(_ => readDoubleArr(in)).toArray
|
||||
}
|
||||
|
||||
private def readBooleanArr(in: DataInputStream): Array[Boolean] = {
|
||||
val len = readInt(in)
|
||||
(0 until len).map(_ => readBoolean(in)).toArray
|
||||
}
|
||||
|
||||
private def readStringArr(in: DataInputStream): Array[String] = {
|
||||
val len = readInt(in)
|
||||
(0 until len).map(_ => readString(in)).toArray
|
||||
}
|
||||
|
||||
private def readRowArr(in: DataInputStream): java.util.List[Row] = {
|
||||
val len = readInt(in)
|
||||
(0 until len).map(_ => readRow(in)).toList.asJava
|
||||
}
|
||||
|
||||
private def readObjectArr(in: DataInputStream): Seq[Any] = {
|
||||
val len = readInt(in)
|
||||
(0 until len).map(_ => readObject(in))
|
||||
}
|
||||
|
||||
private def readList(dis: DataInputStream): Array[_] = {
|
||||
val arrType = readObjectType(dis)
|
||||
arrType match {
|
||||
case 'i' => readIntArr(dis)
|
||||
case 'g' => readLongArr(dis)
|
||||
case 'c' => readStringArr(dis)
|
||||
case 'd' => readDoubleArr(dis)
|
||||
case 'A' => readDoubleArrArr(dis)
|
||||
case 'b' => readBooleanArr(dis)
|
||||
case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
|
||||
case 'r' => readBytesArr(dis)
|
||||
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
|
||||
}
|
||||
}
|
||||
|
||||
private def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
|
||||
val len = readInt(in)
|
||||
if (len > 0) {
|
||||
val keysType = readObjectType(in)
|
||||
val keysLen = readInt(in)
|
||||
val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
|
||||
|
||||
val valuesLen = readInt(in)
|
||||
val values = (0 until valuesLen).map(_ => {
|
||||
val valueType = readObjectType(in)
|
||||
readTypedObject(in, valueType)
|
||||
})
|
||||
keys.zip(values).toMap.asJava
|
||||
} else {
|
||||
new java.util.HashMap[Object, Object]()
|
||||
}
|
||||
}
|
||||
|
||||
// Using the same mapping as SparkR implementation for now
|
||||
// Methods to write out data from Java to .NET.
|
||||
//
|
||||
// Type mapping from Java to .NET:
|
||||
//
|
||||
// void -> NULL
|
||||
// Int -> integer
|
||||
// String -> character
|
||||
// Boolean -> logical
|
||||
// Float -> double
|
||||
// Double -> double
|
||||
// Long -> long
|
||||
// Array[Byte] -> raw
|
||||
// Date -> Date
|
||||
// Time -> POSIXct
|
||||
//
|
||||
// Array[T] -> list()
|
||||
// Object -> jobj
|
||||
|
||||
def writeType(dos: DataOutputStream, typeStr: String): Unit = {
|
||||
typeStr match {
|
||||
case "void" => dos.writeByte('n')
|
||||
case "character" => dos.writeByte('c')
|
||||
case "double" => dos.writeByte('d')
|
||||
case "doublearray" => dos.writeByte('A')
|
||||
case "long" => dos.writeByte('g')
|
||||
case "integer" => dos.writeByte('i')
|
||||
case "logical" => dos.writeByte('b')
|
||||
case "date" => dos.writeByte('D')
|
||||
case "time" => dos.writeByte('t')
|
||||
case "raw" => dos.writeByte('r')
|
||||
case "list" => dos.writeByte('l')
|
||||
case "jobj" => dos.writeByte('j')
|
||||
case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
|
||||
}
|
||||
}
|
||||
|
||||
def writeObject(dos: DataOutputStream, value: Object): Unit = {
|
||||
if (value == null || value == Unit) {
|
||||
writeType(dos, "void")
|
||||
} else {
|
||||
value.getClass.getName match {
|
||||
case "java.lang.String" =>
|
||||
writeType(dos, "character")
|
||||
writeString(dos, value.asInstanceOf[String])
|
||||
case "float" | "java.lang.Float" =>
|
||||
writeType(dos, "double")
|
||||
writeDouble(dos, value.asInstanceOf[Float].toDouble)
|
||||
case "double" | "java.lang.Double" =>
|
||||
writeType(dos, "double")
|
||||
writeDouble(dos, value.asInstanceOf[Double])
|
||||
case "long" | "java.lang.Long" =>
|
||||
writeType(dos, "long")
|
||||
writeLong(dos, value.asInstanceOf[Long])
|
||||
case "int" | "java.lang.Integer" =>
|
||||
writeType(dos, "integer")
|
||||
writeInt(dos, value.asInstanceOf[Int])
|
||||
case "boolean" | "java.lang.Boolean" =>
|
||||
writeType(dos, "logical")
|
||||
writeBoolean(dos, value.asInstanceOf[Boolean])
|
||||
case "java.sql.Date" =>
|
||||
writeType(dos, "date")
|
||||
writeDate(dos, value.asInstanceOf[Date])
|
||||
case "java.sql.Time" =>
|
||||
writeType(dos, "time")
|
||||
writeTime(dos, value.asInstanceOf[Time])
|
||||
case "java.sql.Timestamp" =>
|
||||
writeType(dos, "time")
|
||||
writeTime(dos, value.asInstanceOf[Timestamp])
|
||||
case "[B" =>
|
||||
writeType(dos, "raw")
|
||||
writeBytes(dos, value.asInstanceOf[Array[Byte]])
|
||||
// TODO: Types not handled right now include
|
||||
// byte, char, short, float
|
||||
|
||||
// Handle arrays
|
||||
case "[Ljava.lang.String;" =>
|
||||
writeType(dos, "list")
|
||||
writeStringArr(dos, value.asInstanceOf[Array[String]])
|
||||
case "[I" =>
|
||||
writeType(dos, "list")
|
||||
writeIntArr(dos, value.asInstanceOf[Array[Int]])
|
||||
case "[J" =>
|
||||
writeType(dos, "list")
|
||||
writeLongArr(dos, value.asInstanceOf[Array[Long]])
|
||||
case "[D" =>
|
||||
writeType(dos, "list")
|
||||
writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
|
||||
case "[[D" =>
|
||||
writeType(dos, "list")
|
||||
writeDoubleArrArr(dos, value.asInstanceOf[Array[Array[Double]]])
|
||||
case "[Z" =>
|
||||
writeType(dos, "list")
|
||||
writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
|
||||
case "[[B" =>
|
||||
writeType(dos, "list")
|
||||
writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
|
||||
case otherName =>
|
||||
// Handle array of objects
|
||||
if (otherName.startsWith("[L")) {
|
||||
val objArr = value.asInstanceOf[Array[Object]]
|
||||
writeType(dos, "list")
|
||||
writeType(dos, "jobj")
|
||||
dos.writeInt(objArr.length)
|
||||
objArr.foreach(o => writeJObj(dos, o))
|
||||
} else {
|
||||
writeType(dos, "jobj")
|
||||
writeJObj(dos, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def writeInt(out: DataOutputStream, value: Int): Unit = {
|
||||
out.writeInt(value)
|
||||
}
|
||||
|
||||
def writeLong(out: DataOutputStream, value: Long): Unit = {
|
||||
out.writeLong(value)
|
||||
}
|
||||
|
||||
private def writeDouble(out: DataOutputStream, value: Double): Unit = {
|
||||
out.writeDouble(value)
|
||||
}
|
||||
|
||||
private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = {
|
||||
out.writeBoolean(value)
|
||||
}
|
||||
|
||||
private def writeDate(out: DataOutputStream, value: Date): Unit = {
|
||||
writeString(out, value.toString)
|
||||
}
|
||||
|
||||
private def writeTime(out: DataOutputStream, value: Time): Unit = {
|
||||
out.writeDouble(value.getTime.toDouble / 1000.0)
|
||||
}
|
||||
|
||||
private def writeTime(out: DataOutputStream, value: Timestamp): Unit = {
|
||||
out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9)
|
||||
}
|
||||
|
||||
def writeString(out: DataOutputStream, value: String): Unit = {
|
||||
val utf8 = value.getBytes(StandardCharsets.UTF_8)
|
||||
val len = utf8.length
|
||||
out.writeInt(len)
|
||||
out.write(utf8, 0, len)
|
||||
}
|
||||
|
||||
private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = {
|
||||
out.writeInt(value.length)
|
||||
out.write(value)
|
||||
}
|
||||
|
||||
def writeJObj(out: DataOutputStream, value: Object): Unit = {
|
||||
val objId = tracker.put(value)
|
||||
writeString(out, objId)
|
||||
}
|
||||
|
||||
private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = {
|
||||
writeType(out, "integer")
|
||||
out.writeInt(value.length)
|
||||
value.foreach(v => out.writeInt(v))
|
||||
}
|
||||
|
||||
private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = {
|
||||
writeType(out, "long")
|
||||
out.writeInt(value.length)
|
||||
value.foreach(v => out.writeLong(v))
|
||||
}
|
||||
|
||||
private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = {
|
||||
writeType(out, "double")
|
||||
out.writeInt(value.length)
|
||||
value.foreach(v => out.writeDouble(v))
|
||||
}
|
||||
|
||||
private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = {
|
||||
writeType(out, "doublearray")
|
||||
out.writeInt(value.length)
|
||||
value.foreach(v => writeDoubleArr(out, v))
|
||||
}
|
||||
|
||||
private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = {
|
||||
writeType(out, "logical")
|
||||
out.writeInt(value.length)
|
||||
value.foreach(v => writeBoolean(out, v))
|
||||
}
|
||||
|
||||
private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = {
|
||||
writeType(out, "character")
|
||||
out.writeInt(value.length)
|
||||
value.foreach(v => writeString(out, v))
|
||||
}
|
||||
|
||||
private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
|
||||
writeType(out, "raw")
|
||||
out.writeInt(value.length)
|
||||
value.foreach(v => writeBytes(out, v))
|
||||
}
|
||||
}
|
||||
|
||||
private object SerializationFormats {
|
||||
val BYTE = "byte"
|
||||
val STRING = "string"
|
||||
val ROW = "row"
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.dotnet
|
||||
|
||||
import java.util.concurrent.{ExecutorService, Executors}
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
/**
|
||||
* Pool of thread executors. There should be a 1-1 correspondence between C# threads
|
||||
* and Java threads.
|
||||
*/
|
||||
object ThreadPool {
|
||||
|
||||
/**
|
||||
* Map from (processId, threadId) to corresponding executor.
|
||||
*/
|
||||
private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
|
||||
new mutable.HashMap[(Int, Int), ExecutorService]()
|
||||
|
||||
/**
|
||||
* Run some code on a particular thread.
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @param task Function to run on the thread.
|
||||
*/
|
||||
def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
|
||||
val executor = getOrCreateExecutor(processId, threadId)
|
||||
val future = executor.submit(new Runnable {
|
||||
override def run(): Unit = task()
|
||||
})
|
||||
|
||||
future.get()
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to delete a particular thread.
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @return True if successful, false if thread does not exist.
|
||||
*/
|
||||
def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
|
||||
executors.remove((processId, threadId)) match {
|
||||
case Some(executorService) =>
|
||||
executorService.shutdown()
|
||||
true
|
||||
case None => false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Shutdown any running ExecutorServices.
|
||||
*/
|
||||
def shutdown(): Unit = synchronized {
|
||||
executors.foreach(_._2.shutdown())
|
||||
executors.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the executor if it exists, otherwise create a new one.
|
||||
* @param processId Integer id of the process.
|
||||
* @param threadId Integer id of the thread.
|
||||
* @return The new or existing executor with the given id.
|
||||
*/
|
||||
private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
|
||||
executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,285 @@
|
|||
/*
|
||||
* Licensed to the .NET Foundation under one or more agreements.
|
||||
* The .NET Foundation licenses this file to you under the MIT license.
|
||||
* See the LICENSE file in the project root for more information.
|
||||
*/
|
||||
|
||||
package org.apache.spark.deploy.dotnet
|
||||
|
||||
import java.io.File
|
||||
import java.net.URI
|
||||
import java.nio.file.attribute.PosixFilePermissions
|
||||
import java.nio.file.{FileSystems, Files, Paths}
|
||||
import java.util.Locale
|
||||
import java.util.concurrent.{Semaphore, TimeUnit}
|
||||
|
||||
import org.apache.commons.io.FilenameUtils
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.spark
|
||||
import org.apache.spark.api.dotnet.DotnetBackend
|
||||
import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil}
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK
|
||||
import org.apache.spark.util.dotnet.{Utils => DotnetUtils}
|
||||
import org.apache.spark.util.{RedirectThread, Utils}
|
||||
import org.apache.spark.{SecurityManager, SparkConf, SparkUserAppException}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.io.StdIn
|
||||
import scala.util.Try
|
||||
|
||||
/**
|
||||
* DotnetRunner class used to launch Spark .NET applications using spark-submit.
|
||||
* It executes .NET application as a subprocess and then has it connect back to
|
||||
* the JVM to access system properties etc.
|
||||
*/
|
||||
object DotnetRunner extends Logging {
|
||||
private val DEBUG_PORT = 5567
|
||||
private val supportedSparkMajorMinorVersionPrefix = "3.2"
|
||||
private val supportedSparkVersions = Set[String]("3.2.0", "3.2.1")
|
||||
|
||||
val SPARK_VERSION = DotnetUtils.normalizeSparkVersion(spark.SPARK_VERSION)
|
||||
|
||||
def main(args: Array[String]): Unit = {
|
||||
if (args.length == 0) {
|
||||
throw new IllegalArgumentException("At least one argument is expected.")
|
||||
}
|
||||
|
||||
DotnetUtils.validateSparkVersions(
|
||||
sys.props
|
||||
.getOrElse(
|
||||
DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.key,
|
||||
DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.defaultValue.get.toString)
|
||||
.toBoolean,
|
||||
spark.SPARK_VERSION,
|
||||
SPARK_VERSION,
|
||||
supportedSparkMajorMinorVersionPrefix,
|
||||
supportedSparkVersions)
|
||||
|
||||
val settings = initializeSettings(args)
|
||||
|
||||
// Determines if this needs to be run in debug mode.
|
||||
// In debug mode this runner will not launch a .NET process.
|
||||
val runInDebugMode = settings._1
|
||||
@volatile var dotnetBackendPortNumber = settings._2
|
||||
var dotnetExecutable = ""
|
||||
var otherArgs: Array[String] = null
|
||||
|
||||
if (!runInDebugMode) {
|
||||
if (args(0).toLowerCase(Locale.ROOT).endsWith(".zip")) {
|
||||
var zipFileName = args(0)
|
||||
val zipFileUri = Try(new URI(zipFileName)).getOrElse(new File(zipFileName).toURI)
|
||||
val workingDir = new File("").getAbsoluteFile
|
||||
val driverDir = new File(workingDir, FilenameUtils.getBaseName(zipFileUri.getPath()))
|
||||
|
||||
// Standalone cluster mode where .NET application is remotely located.
|
||||
if (zipFileUri.getScheme() != "file") {
|
||||
zipFileName = downloadDriverFile(zipFileName, workingDir.getAbsolutePath).getName
|
||||
}
|
||||
|
||||
logInfo(s"Unzipping .NET driver $zipFileName to $driverDir")
|
||||
DotnetUtils.unzip(new File(zipFileName), driverDir)
|
||||
|
||||
// Reuse windows-specific formatting in PythonRunner.
|
||||
dotnetExecutable = PythonRunner.formatPath(resolveDotnetExecutable(driverDir, args(1)))
|
||||
otherArgs = args.slice(2, args.length)
|
||||
} else {
|
||||
// Reuse windows-specific formatting in PythonRunner.
|
||||
dotnetExecutable = PythonRunner.formatPath(args(0))
|
||||
otherArgs = args.slice(1, args.length)
|
||||
}
|
||||
} else {
|
||||
otherArgs = args.slice(1, args.length)
|
||||
}
|
||||
|
||||
val processParameters = new java.util.ArrayList[String]
|
||||
processParameters.add(dotnetExecutable)
|
||||
otherArgs.foreach(arg => processParameters.add(arg))
|
||||
|
||||
logInfo(s"Starting DotnetBackend with $dotnetExecutable.")
|
||||
|
||||
// Time to wait for DotnetBackend to initialize in seconds.
|
||||
val backendTimeout = sys.env.getOrElse("DOTNETBACKEND_TIMEOUT", "120").toInt
|
||||
|
||||
// Launch a DotnetBackend server for the .NET process to connect to; this will let it see our
|
||||
// Java system properties etc.
|
||||
val dotnetBackend = new DotnetBackend()
|
||||
val initialized = new Semaphore(0)
|
||||
val dotnetBackendThread = new Thread("DotnetBackend") {
|
||||
override def run() {
|
||||
// need to get back dotnetBackendPortNumber because if the value passed to init is 0
|
||||
// the port number is dynamically assigned in the backend
|
||||
dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber)
|
||||
logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber")
|
||||
initialized.release()
|
||||
dotnetBackend.run()
|
||||
}
|
||||
}
|
||||
|
||||
dotnetBackendThread.start()
|
||||
|
||||
if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) {
|
||||
if (!runInDebugMode) {
|
||||
var returnCode = -1
|
||||
var process: Process = null
|
||||
try {
|
||||
val builder = new ProcessBuilder(processParameters)
|
||||
val env = builder.environment()
|
||||
env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString)
|
||||
|
||||
for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) {
|
||||
env.put(key, value)
|
||||
logInfo(s"Adding key=$key and value=$value to environment")
|
||||
}
|
||||
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
|
||||
process = builder.start()
|
||||
|
||||
// Redirect stdin of JVM process to stdin of .NET process.
|
||||
new RedirectThread(System.in, process.getOutputStream, "redirect JVM input").start()
|
||||
// Redirect stdout and stderr of .NET process.
|
||||
new RedirectThread(process.getInputStream, System.out, "redirect .NET stdout").start()
|
||||
new RedirectThread(process.getErrorStream, System.out, "redirect .NET stderr").start()
|
||||
|
||||
process.waitFor()
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
logThrowable(t)
|
||||
} finally {
|
||||
returnCode = closeDotnetProcess(process)
|
||||
closeBackend(dotnetBackend)
|
||||
}
|
||||
|
||||
if (returnCode != 0) {
|
||||
throw new SparkUserAppException(returnCode)
|
||||
} else {
|
||||
logInfo(s".NET application exited successfully")
|
||||
}
|
||||
// TODO: The following is causing the following error:
|
||||
// INFO ApplicationMaster: Final app status: FAILED, exitCode: 16,
|
||||
// (reason: Shutdown hook called before final status was reported.)
|
||||
// DotnetUtils.exit(returnCode)
|
||||
} else {
|
||||
// scalastyle:off println
|
||||
println("***********************************************************************")
|
||||
println("* .NET Backend running debug mode. Press enter to exit *")
|
||||
println("***********************************************************************")
|
||||
// scalastyle:on println
|
||||
|
||||
StdIn.readLine()
|
||||
closeBackend(dotnetBackend)
|
||||
DotnetUtils.exit(0)
|
||||
}
|
||||
} else {
|
||||
logError(s"DotnetBackend did not initialize in $backendTimeout seconds")
|
||||
DotnetUtils.exit(-1)
|
||||
}
|
||||
}
|
||||
|
||||
// When the executable is downloaded as part of zip file, check if the file exists
|
||||
// after zip file is unzipped under the given dir. Once it is found, change the
|
||||
// permission to executable (only for Unix systems, since the zip file may have been
|
||||
// created under Windows. Finally, the absolute path for the executable is returned.
|
||||
private def resolveDotnetExecutable(dir: File, dotnetExecutable: String): String = {
|
||||
val path = Paths.get(dir.getAbsolutePath, dotnetExecutable)
|
||||
val resolvedExecutable = if (Files.isRegularFile(path)) {
|
||||
path.toAbsolutePath.toString
|
||||
} else {
|
||||
Files
|
||||
.walk(FileSystems.getDefault.getPath(dir.getAbsolutePath))
|
||||
.iterator()
|
||||
.asScala
|
||||
.find(path => Files.isRegularFile(path) && path.getFileName.toString == dotnetExecutable) match {
|
||||
case Some(path) => path.toAbsolutePath.toString
|
||||
case None =>
|
||||
throw new IllegalArgumentException(
|
||||
s"Failed to find $dotnetExecutable under ${dir.getAbsolutePath}")
|
||||
}
|
||||
}
|
||||
|
||||
if (DotnetUtils.supportPosix) {
|
||||
Files.setPosixFilePermissions(
|
||||
Paths.get(resolvedExecutable),
|
||||
PosixFilePermissions.fromString("rwxr-xr-x"))
|
||||
}
|
||||
|
||||
resolvedExecutable
|
||||
}
|
||||
|
||||
/**
|
||||
* Download HDFS file into the supplied directory and return its local path.
|
||||
* Will throw an exception if there are errors during downloading.
|
||||
*/
|
||||
private def downloadDriverFile(hdfsFilePath: String, driverDir: String): File = {
|
||||
val sparkConf = new SparkConf()
|
||||
val filePath = new Path(hdfsFilePath)
|
||||
|
||||
val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf)
|
||||
val jarFileName = filePath.getName
|
||||
val localFile = new File(driverDir, jarFileName)
|
||||
|
||||
if (!localFile.exists()) { // May already exist if running multiple workers on one node
|
||||
logInfo(s"Copying user file $filePath to $driverDir")
|
||||
Utils.fetchFile(
|
||||
hdfsFilePath,
|
||||
new File(driverDir),
|
||||
sparkConf,
|
||||
hadoopConf,
|
||||
System.currentTimeMillis(),
|
||||
useCache = false)
|
||||
}
|
||||
|
||||
if (!localFile.exists()) {
|
||||
throw new Exception(s"Did not see expected $jarFileName in $driverDir")
|
||||
}
|
||||
|
||||
localFile
|
||||
}
|
||||
|
||||
private def closeBackend(dotnetBackend: DotnetBackend): Unit = {
|
||||
logInfo("Closing DotnetBackend")
|
||||
dotnetBackend.close()
|
||||
}
|
||||
|
||||
private def closeDotnetProcess(dotnetProcess: Process): Int = {
|
||||
if (dotnetProcess == null) {
|
||||
return -1
|
||||
} else if (!dotnetProcess.isAlive) {
|
||||
return dotnetProcess.exitValue()
|
||||
}
|
||||
|
||||
// Try to (gracefully on Linux) kill the process and resort to force if interrupted
|
||||
var returnCode = -1
|
||||
logInfo("Closing .NET process")
|
||||
try {
|
||||
dotnetProcess.destroy()
|
||||
returnCode = dotnetProcess.waitFor()
|
||||
} catch {
|
||||
case _: InterruptedException =>
|
||||
logInfo(
|
||||
"Thread interrupted while waiting for graceful close. Forcefully closing .NET process")
|
||||
returnCode = dotnetProcess.destroyForcibly().waitFor()
|
||||
case t: Throwable =>
|
||||
logThrowable(t)
|
||||
}
|
||||
|
||||
returnCode
|
||||
}
|
||||
|
||||
private def initializeSettings(args: Array[String]): (Boolean, Int) = {
|
||||
val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase(
|
||||
"debug")
|
||||
var portNumber = 0
|
||||
if (runInDebugMode) {
|
||||
if (args.length == 1) {
|
||||
portNumber = DEBUG_PORT
|
||||
} else if (args.length == 2) {
|
||||
portNumber = Integer.parseInt(args(1))
|
||||
}
|
||||
}
|
||||
|
||||
(runInDebugMode, portNumber)
|
||||
}
|
||||
|
||||
private def logThrowable(throwable: Throwable): Unit =
|
||||
logError(s"${throwable.getMessage} \n ${throwable.getStackTrace.mkString("\n")}")
|
||||
}
|
|
@ -0,0 +1,18 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.internal.config.dotnet

import org.apache.spark.internal.config.ConfigBuilder

private[spark] object Dotnet {
  val DOTNET_NUM_BACKEND_THREADS = ConfigBuilder("spark.dotnet.numDotnetBackendThreads").intConf
    .createWithDefault(10)

  val DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK =
    ConfigBuilder("spark.dotnet.ignoreSparkPatchVersionCheck").booleanConf
      .createWithDefault(false)
}
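These two ConfigBuilder entries are the only .NET-specific Spark settings introduced here. A minimal sketch of setting them on a SparkConf, assuming the usual SparkConf / --conf mechanisms (defaults shown in comments):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.dotnet.numDotnetBackendThreads", "16")        // default: 10
  .set("spark.dotnet.ignoreSparkPatchVersionCheck", "true") // default: false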
@ -0,0 +1,33 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.sql.api.dotnet

import org.apache.spark.api.dotnet.CallbackClient
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.streaming.DataStreamWriter

class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int) extends Logging {
  def call(batchDF: DataFrame, batchId: Long): Unit =
    callbackClient.send(
      callbackId,
      (dos, serDe) => {
        serDe.writeJObj(dos, batchDF)
        serDe.writeLong(dos, batchId)
      })
}

object DotnetForeachBatchHelper {
  def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = {
    val dotnetForeachFunc = client match {
      case Some(value) => new DotnetForeachBatchFunction(value, callbackId)
      case None => throw new Exception("CallbackClient is null.")
    }

    dsw.foreachBatch(dotnetForeachFunc.call _)
  }
}
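DotnetForeachBatchHelper bridges a streaming foreachBatch callback to the .NET side through the CallbackClient. A minimal sketch of the JVM-side wiring, assuming an existing streaming DataFrame df, a connected Option[CallbackClient], and a callback id already registered by the .NET process:

import org.apache.spark.api.dotnet.CallbackClient
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.streaming.DataStreamWriter

def wireForeachBatch(df: DataFrame, client: Option[CallbackClient], callbackId: Int): Unit = {
  val writer: DataStreamWriter[Row] = df.writeStream
  DotnetForeachBatchHelper.callForeachBatch(client, writer, callbackId)
  writer.start() // each micro-batch now triggers the .NET callback via the client
}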
@ -0,0 +1,37 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.sql.api.dotnet

import java.util.{List => JList, Map => JMap}

import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonFunction}
import org.apache.spark.broadcast.Broadcast

object SQLUtils {

  /**
   * Exposes createPythonFunction to the .NET client to enable registering UDFs.
   */
  def createPythonFunction(
      command: Array[Byte],
      envVars: JMap[String, String],
      pythonIncludes: JList[String],
      pythonExec: String,
      pythonVersion: String,
      broadcastVars: JList[Broadcast[PythonBroadcast]],
      accumulator: PythonAccumulatorV2): PythonFunction = {

    PythonFunction(
      command,
      envVars,
      pythonIncludes,
      pythonExec,
      pythonVersion,
      broadcastVars,
      accumulator)
  }
}
@ -0,0 +1,30 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.sql.test

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.MemoryStream

object TestUtils {

  /**
   * Helper method to create typed MemoryStreams intended for use in unit tests.
   * @param sqlContext The SQLContext.
   * @param streamType The type of memory stream to create. This string is the `Name`
   *                   property of the dotnet type.
   * @return A typed MemoryStream.
   */
  def createMemoryStream(implicit sqlContext: SQLContext, streamType: String): MemoryStream[_] = {
    import sqlContext.implicits._

    streamType match {
      case "Int32" => MemoryStream[Int]
      case "String" => MemoryStream[String]
      case _ => throw new Exception(s"$streamType not supported")
    }
  }
}
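createMemoryStream takes both parameters in a single implicit list, so callers can also pass them explicitly. A short usage sketch, assuming an active SparkSession named spark:

// Sketch only: create a stream of Ints and feed it some test data.
val stream = TestUtils.createMemoryStream(spark.sqlContext, "Int32")
  .asInstanceOf[org.apache.spark.sql.execution.streaming.MemoryStream[Int]]
stream.addData(1, 2, 3)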
@ -0,0 +1,254 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.util.dotnet

import java.io._
import java.nio.file.attribute.PosixFilePermission
import java.nio.file.attribute.PosixFilePermission._
import java.nio.file.{FileSystems, Files}
import java.util.{Timer, TimerTask}

import org.apache.commons.compress.archivers.zip.{ZipArchiveEntry, ZipArchiveOutputStream, ZipFile}
import org.apache.commons.io.{FileUtils, IOUtils}
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK

import scala.collection.JavaConverters._
import scala.collection.Set

/**
 * Utility methods.
 */
object Utils extends Logging {
  private val posixFilePermissions = Array(
    OWNER_READ,
    OWNER_WRITE,
    OWNER_EXECUTE,
    GROUP_READ,
    GROUP_WRITE,
    GROUP_EXECUTE,
    OTHERS_READ,
    OTHERS_WRITE,
    OTHERS_EXECUTE)

  val supportPosix: Boolean =
    FileSystems.getDefault.supportedFileAttributeViews().contains("posix")

  /**
   * Compress all files under the given directory into one zip file and write it to the target path.
   *
   * @param sourceDir source directory to zip
   * @param targetZipFile target zip file
   */
  def zip(sourceDir: File, targetZipFile: File): Unit = {
    var fos: FileOutputStream = null
    var zos: ZipArchiveOutputStream = null
    try {
      fos = new FileOutputStream(targetZipFile)
      zos = new ZipArchiveOutputStream(fos)

      val sourcePath = sourceDir.toPath
      FileUtils.listFiles(sourceDir, null, true).asScala.foreach { file =>
        var in: FileInputStream = null
        try {
          val path = file.toPath
          val entry = new ZipArchiveEntry(sourcePath.relativize(path).toString)
          if (supportPosix) {
            entry.setUnixMode(
              permissionsToMode(Files.getPosixFilePermissions(path).asScala)
                | (if (entry.getName.endsWith(".exe")) 0x1ED else 0x1A4))
          } else if (entry.getName.endsWith(".exe")) {
            entry.setUnixMode(0x1ED) // 755
          } else {
            entry.setUnixMode(0x1A4) // 644
          }
          zos.putArchiveEntry(entry)

          in = new FileInputStream(file)
          IOUtils.copy(in, zos)
          zos.closeArchiveEntry()
        } finally {
          IOUtils.closeQuietly(in)
        }
      }
    } finally {
      IOUtils.closeQuietly(zos)
      IOUtils.closeQuietly(fos)
    }
  }

  /**
   * Unzip a file to the given directory.
   *
   * @param file file to be unzipped
   * @param targetDir target directory
   */
  def unzip(file: File, targetDir: File): Unit = {
    var zipFile: ZipFile = null
    try {
      targetDir.mkdirs()
      zipFile = new ZipFile(file)
      zipFile.getEntries.asScala.foreach { entry =>
        val targetFile = new File(targetDir, entry.getName)

        if (targetFile.exists()) {
          logWarning(
            s"Target file/directory $targetFile already exists. Skip it for now. " +
              s"Make sure this is expected.")
        } else {
          if (entry.isDirectory) {
            targetFile.mkdirs()
          } else {
            targetFile.getParentFile.mkdirs()
            val input = zipFile.getInputStream(entry)
            val output = new FileOutputStream(targetFile)
            IOUtils.copy(input, output)
            IOUtils.closeQuietly(input)
            IOUtils.closeQuietly(output)
            if (supportPosix) {
              val permissions = modeToPermissions(entry.getUnixMode)
              // When run on a Unix system, permissions will be empty, thus skip
              // setting the empty permissions (which would clear the previous permissions).
              if (permissions.nonEmpty) {
                Files.setPosixFilePermissions(targetFile.toPath, permissions.asJava)
              }
            }
          }
        }
      }
    } catch {
      case e: Exception => logError("exception caught during decompression:" + e)
    } finally {
      ZipFile.closeQuietly(zipFile)
    }
  }

  /**
   * Exits the JVM, trying to do it nicely, otherwise doing it nastily.
   *
   * @param status the exit status, zero for OK, non-zero for error
   * @param maxDelayMillis the maximum delay in milliseconds
   */
  def exit(status: Int, maxDelayMillis: Long) {
    try {
      logInfo(s"Utils.exit() with status: $status, maxDelayMillis: $maxDelayMillis")

      // set up a timer, so if the nice exit fails, the nasty exit happens
      val timer = new Timer()
      timer.schedule(new TimerTask() {

        override def run() {
          Runtime.getRuntime.halt(status)
        }
      }, maxDelayMillis)
      // try to exit nicely
      System.exit(status)
    } catch {
      // exit nastily if we have a problem
      case _: Throwable => Runtime.getRuntime.halt(status)
    } finally {
      // should never get here
      Runtime.getRuntime.halt(status)
    }
  }

  /**
   * Exits the JVM, trying to do it nicely, waiting up to 1 second.
   *
   * @param status the exit status, zero for OK, non-zero for error
   */
  def exit(status: Int): Unit = {
    exit(status, 1000)
  }

  /**
   * Normalize the Spark version by taking the first three numbers.
   * For example:
   * x.y.z => x.y.z
   * x.y.z.xxx.yyy => x.y.z
   * x.y => x.y
   *
   * @param version the Spark version to normalize
   * @return Normalized Spark version.
   */
  def normalizeSparkVersion(version: String): String = {
    version
      .split('.')
      .take(3)
      .zipWithIndex
      .map({
        case (element, index) => {
          index match {
            case 2 => element.split("\\D+").lift(0).getOrElse("")
            case _ => element
          }
        }
      })
      .mkString(".")
  }

  /**
   * Validates the normalized spark version by verifying:
   * - Spark version starts with sparkMajorMinorVersionPrefix.
   * - If ignoreSparkPatchVersion is
   *   - true: valid
   *   - false: check if the spark version is in supportedSparkVersions.
   * @param ignoreSparkPatchVersion Ignore spark patch version.
   * @param sparkVersion The spark version.
   * @param normalizedSparkVersion The normalized spark version.
   * @param supportedSparkMajorMinorVersionPrefix The spark major and minor version to validate against.
   * @param supportedSparkVersions The set of supported spark versions.
   */
  def validateSparkVersions(
      ignoreSparkPatchVersion: Boolean,
      sparkVersion: String,
      normalizedSparkVersion: String,
      supportedSparkMajorMinorVersionPrefix: String,
      supportedSparkVersions: Set[String]): Unit = {
    if (!normalizedSparkVersion.startsWith(s"$supportedSparkMajorMinorVersionPrefix.")) {
      throw new IllegalArgumentException(
        s"Unsupported spark version used: '$sparkVersion'. " +
          s"Normalized spark version used: '$normalizedSparkVersion'. " +
          s"Supported spark major.minor version: '$supportedSparkMajorMinorVersionPrefix'.")
    } else if (ignoreSparkPatchVersion) {
      logWarning(
        s"Ignoring spark patch version. Spark version used: '$sparkVersion'. " +
          s"Normalized spark version used: '$normalizedSparkVersion'. " +
          s"Spark major.minor prefix used: '$supportedSparkMajorMinorVersionPrefix'.")
    } else if (!supportedSparkVersions(normalizedSparkVersion)) {
      val supportedVersions = supportedSparkVersions.toSeq.sorted.mkString(", ")
      throw new IllegalArgumentException(
        s"Unsupported spark version used: '$sparkVersion'. " +
          s"Normalized spark version used: '$normalizedSparkVersion'. " +
          s"Supported versions: '$supportedVersions'.")
    }
  }

  private[spark] def listZipFileEntries(file: File): Array[String] = {
    var zipFile: ZipFile = null
    try {
      zipFile = new ZipFile(file)
      zipFile.getEntries.asScala.map(_.getName).toArray
    } finally {
      ZipFile.closeQuietly(zipFile)
    }
  }

  private[this] def permissionsToMode(permissions: Set[PosixFilePermission]): Int = {
    posixFilePermissions.foldLeft(0) { (mode, perm) =>
      (mode << 1) | (if (permissions.contains(perm)) 1 else 0)
    }
  }

  private[this] def modeToPermissions(mode: Int): Set[PosixFilePermission] = {
    posixFilePermissions.zipWithIndex
      .filter { case (_, i) => (mode & (0x100 >>> i)) != 0 }
      .map(_._1)
      .toSet
  }
}
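The version helpers above drive the compatibility check for the new 3.2 module. A small illustration, using only behavior visible in this diff (see also UtilsTest below):

Utils.normalizeSparkVersion("3.2.1-SNAPSHOT") // "3.2.1": only the leading digits of the patch segment survive
Utils.normalizeSparkVersion("3.2")            // "3.2": fewer than three segments are returned unchanged

Utils.validateSparkVersions(
  ignoreSparkPatchVersion = false,
  sparkVersion = "3.2.0",
  normalizedSparkVersion = Utils.normalizeSparkVersion("3.2.0"),
  supportedSparkMajorMinorVersionPrefix = "3.2",
  supportedSparkVersions = Set("3.2.0")) // passes; "3.2.1" would throw unless the patch check is ignored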
@ -0,0 +1,68 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.api.dotnet

import Extensions._
import org.junit.Assert._
import org.junit.{After, Before, Test}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

@Test
class DotnetBackendHandlerTest {
  private var backend: DotnetBackend = _
  private var tracker: JVMObjectTracker = _
  private var handler: DotnetBackendHandler = _

  @Before
  def before(): Unit = {
    backend = new DotnetBackend
    tracker = new JVMObjectTracker
    handler = new DotnetBackendHandler(backend, tracker)
  }

  @After
  def after(): Unit = {
    backend.close()
  }

  @Test
  def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = {
    val message = givenMessage(m => {
      val serDe = new SerDe(null)
      m.writeBoolean(true) // static method
      serDe.writeInt(m, 1) // processId
      serDe.writeInt(m, 1) // threadId
      serDe.writeString(m, "DotnetHandler") // class name
      serDe.writeString(m, "connectCallback") // command (method) name
      m.writeInt(2) // number of arguments
      m.writeByte('c') // 1st argument type (string)
      serDe.writeString(m, "127.0.0.1") // 1st argument value (host)
      m.writeByte('i') // 2nd argument type (integer)
      m.writeInt(0) // 2nd argument value (port)
    })

    val payload = handler.handleBackendRequest(message)
    val reply = new DataInputStream(new ByteArrayInputStream(payload))

    assertEquals(
      "status code must be successful.", 0, reply.readInt())
    assertEquals('j', reply.readByte())
    assertEquals(1, reply.readInt())
    val trackingId = new String(reply.readNBytes(1), "UTF-8")
    assertEquals("1", trackingId)
    val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull
    assertEquals(classOf[CallbackClient], client.getClass)
  }

  private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    func(new DataOutputStream(buffer))
    buffer.toByteArray
  }
}
@ -0,0 +1,39 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.api.dotnet

import org.junit.Assert._
import org.junit.{After, Before, Test}

import java.net.InetAddress

@Test
class DotnetBackendTest {
  private var backend: DotnetBackend = _

  @Before
  def before(): Unit = {
    backend = new DotnetBackend
  }

  @After
  def after(): Unit = {
    backend.close()
  }

  @Test
  def shouldNotResetCallbackClient(): Unit = {
    // Specifying port = 0 to select port dynamically.
    backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)

    assertTrue(backend.callbackClient.isDefined)
    assertThrows(classOf[Exception], () => {
      backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
    })
  }
}
@ -0,0 +1,20 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.api.dotnet

import java.io.DataInputStream

private[dotnet] object Extensions {
  implicit class DataInputStreamExt(stream: DataInputStream) {
    def readNBytes(n: Int): Array[Byte] = {
      val buf = new Array[Byte](n)
      stream.readFully(buf)
      buf
    }
  }
}
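The readNBytes extension simply wraps DataInputStream.readFully so callers get back exactly n bytes. A tiny usage sketch; because the object is private[dotnet], code like this would have to live under org.apache.spark.api.dotnet:

import java.io.{ByteArrayInputStream, DataInputStream}
import Extensions._

val in = new DataInputStream(new ByteArrayInputStream(Array[Byte](1, 2, 3, 4)))
val firstTwo = in.readNBytes(2) // Array(1, 2); blocks until exactly 2 bytes have been read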
@ -0,0 +1,42 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.api.dotnet

import org.junit.Test

@Test
class JVMObjectTrackerTest {

  @Test
  def shouldReleaseAllReferences(): Unit = {
    val tracker = new JVMObjectTracker
    val firstId = tracker.put(new Object)
    val secondId = tracker.put(new Object)
    val thirdId = tracker.put(new Object)

    tracker.clear()

    assert(tracker.get(firstId).isEmpty)
    assert(tracker.get(secondId).isEmpty)
    assert(tracker.get(thirdId).isEmpty)
  }

  @Test
  def shouldResetCounter(): Unit = {
    val tracker = new JVMObjectTracker
    val firstId = tracker.put(new Object)
    val secondId = tracker.put(new Object)

    tracker.clear()

    val thirdId = tracker.put(new Object)

    assert(firstId.equals("1"))
    assert(secondId.equals("2"))
    assert(thirdId.equals("1"))
  }
}
@ -0,0 +1,373 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.api.dotnet

import org.apache.spark.api.dotnet.Extensions._
import org.apache.spark.sql.Row
import org.junit.Assert._
import org.junit.{Before, Test}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.sql.Date
import scala.collection.JavaConverters._

@Test
class SerDeTest {
  private var serDe: SerDe = _
  private var tracker: JVMObjectTracker = _

  @Before
  def before(): Unit = {
    tracker = new JVMObjectTracker
    serDe = new SerDe(tracker)
  }

  @Test
  def shouldReadNull(): Unit = {
    val input = givenInput(in => {
      in.writeByte('n')
    })

    assertEquals(null, serDe.readObject(input))
  }

  @Test
  def shouldThrowForUnsupportedTypes(): Unit = {
    val input = givenInput(in => {
      in.writeByte('_')
    })

    assertThrows(classOf[IllegalArgumentException], () => {
      serDe.readObject(input)
    })
  }

  @Test
  def shouldReadInteger(): Unit = {
    val input = givenInput(in => {
      in.writeByte('i')
      in.writeInt(42)
    })

    assertEquals(42, serDe.readObject(input))
  }

  @Test
  def shouldReadLong(): Unit = {
    val input = givenInput(in => {
      in.writeByte('g')
      in.writeLong(42)
    })

    assertEquals(42L, serDe.readObject(input))
  }

  @Test
  def shouldReadDouble(): Unit = {
    val input = givenInput(in => {
      in.writeByte('d')
      in.writeDouble(42.42)
    })

    assertEquals(42.42, serDe.readObject(input))
  }

  @Test
  def shouldReadBoolean(): Unit = {
    val input = givenInput(in => {
      in.writeByte('b')
      in.writeBoolean(true)
    })

    assertEquals(true, serDe.readObject(input))
  }

  @Test
  def shouldReadString(): Unit = {
    val payload = "Spark Dotnet"
    val input = givenInput(in => {
      in.writeByte('c')
      in.writeInt(payload.getBytes("UTF-8").length)
      in.write(payload.getBytes("UTF-8"))
    })

    assertEquals(payload, serDe.readObject(input))
  }

  @Test
  def shouldReadMap(): Unit = {
    val input = givenInput(in => {
      in.writeByte('e') // map type descriptor
      in.writeInt(3) // size
      in.writeByte('i') // key type
      in.writeInt(3) // number of keys
      in.writeInt(11) // first key
      in.writeInt(22) // second key
      in.writeInt(33) // third key
      in.writeInt(3) // number of values
      in.writeByte('b') // first value type
      in.writeBoolean(true) // first value
      in.writeByte('d') // second value type
      in.writeDouble(42.42) // second value
      in.writeByte('n') // third type & value
    })

    assertEquals(
      mapAsJavaMap(Map(
        11 -> true,
        22 -> 42.42,
        33 -> null)),
      serDe.readObject(input))
  }

  @Test
  def shouldReadEmptyMap(): Unit = {
    val input = givenInput(in => {
      in.writeByte('e') // map type descriptor
      in.writeInt(0) // size
    })

    assertEquals(mapAsJavaMap(Map()), serDe.readObject(input))
  }

  @Test
  def shouldReadBytesArray(): Unit = {
    val input = givenInput(in => {
      in.writeByte('r') // byte array type descriptor
      in.writeInt(3) // length
      in.write(Array[Byte](1, 2, 3)) // payload
    })

    assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]])
  }

  @Test
  def shouldReadEmptyBytesArray(): Unit = {
    val input = givenInput(in => {
      in.writeByte('r') // byte array type descriptor
      in.writeInt(0) // length
    })

    assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]])
  }

  @Test
  def shouldReadEmptyList(): Unit = {
    val input = givenInput(in => {
      in.writeByte('l') // type descriptor
      in.writeByte('i') // element type
      in.writeInt(0) // length
    })

    assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]])
  }

  @Test
  def shouldReadList(): Unit = {
    val input = givenInput(in => {
      in.writeByte('l') // type descriptor
      in.writeByte('b') // element type
      in.writeInt(3) // length
      in.writeBoolean(true)
      in.writeBoolean(false)
      in.writeBoolean(true)
    })

    assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]])
  }

  @Test
  def shouldThrowWhenReadingListWithUnsupportedType(): Unit = {
    val input = givenInput(in => {
      in.writeByte('l') // type descriptor
      in.writeByte('_') // unsupported element type
    })

    assertThrows(classOf[IllegalArgumentException], () => {
      serDe.readObject(input)
    })
  }

  @Test
  def shouldReadDate(): Unit = {
    val input = givenInput(in => {
      val date = "2020-12-31"
      in.writeByte('D') // type descriptor
      in.writeInt(date.getBytes("UTF-8").length) // date string size
      in.write(date.getBytes("UTF-8"))
    })

    assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input))
  }

  @Test
  def shouldReadObject(): Unit = {
    val trackingObject = new Object
    tracker.put(trackingObject)
    val input = givenInput(in => {
      val objectIndex = "1"
      in.writeByte('j') // type descriptor
      in.writeInt(objectIndex.getBytes("UTF-8").length) // size
      in.write(objectIndex.getBytes("UTF-8"))
    })

    assertSame(trackingObject, serDe.readObject(input))
  }

  @Test
  def shouldThrowWhenReadingNonTrackingObject(): Unit = {
    val input = givenInput(in => {
      val objectIndex = "42"
      in.writeByte('j') // type descriptor
      in.writeInt(objectIndex.getBytes("UTF-8").length) // size
      in.write(objectIndex.getBytes("UTF-8"))
    })

    assertThrows(classOf[NoSuchElementException], () => {
      serDe.readObject(input)
    })
  }

  @Test
  def shouldReadSparkRows(): Unit = {
    val input = givenInput(in => {
      in.writeByte('R') // type descriptor
      in.writeInt(2) // number of rows
      in.writeInt(1) // number of elements in 1st row
      in.writeByte('i') // type of 1st element in 1st row
      in.writeInt(11)
      in.writeInt(3) // number of elements in 2nd row
      in.writeByte('b') // type of 1st element in 2nd row
      in.writeBoolean(true)
      in.writeByte('d') // type of 2nd element in 2nd row
      in.writeDouble(42.24)
      in.writeByte('g') // type of 3rd element in 2nd row
      in.writeLong(99)
    })

    assertEquals(
      seqAsJavaList(Seq(
        Row.fromSeq(Seq(11)),
        Row.fromSeq(Seq(true, 42.24, 99)))),
      serDe.readObject(input))
  }

  @Test
  def shouldReadArrayOfObjects(): Unit = {
    val input = givenInput(in => {
      in.writeByte('O') // type descriptor
      in.writeInt(2) // number of elements
      in.writeByte('i') // type of 1st element
      in.writeInt(42)
      in.writeByte('b') // type of 2nd element
      in.writeBoolean(true)
    })

    assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]])
  }

  @Test
  def shouldWriteNull(): Unit = {
    val in = whenOutput(out => {
      serDe.writeObject(out, null)
      serDe.writeObject(out, Unit)
    })

    assertEquals(in.readByte(), 'n')
    assertEquals(in.readByte(), 'n')
    assertEndOfStream(in)
  }

  @Test
  def shouldWriteString(): Unit = {
    val sparkDotnet = "Spark Dotnet"
    val in = whenOutput(out => {
      serDe.writeObject(out, sparkDotnet)
    })

    assertEquals(in.readByte(), 'c') // object type
    assertEquals(in.readInt(), sparkDotnet.length) // length
    assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8"))
    assertEndOfStream(in)
  }

  @Test
  def shouldWritePrimitiveTypes(): Unit = {
    val in = whenOutput(out => {
      serDe.writeObject(out, 42.24f.asInstanceOf[Object])
      serDe.writeObject(out, 42L.asInstanceOf[Object])
      serDe.writeObject(out, 42.asInstanceOf[Object])
      serDe.writeObject(out, true.asInstanceOf[Object])
    })

    assertEquals(in.readByte(), 'd')
    assertEquals(in.readDouble(), 42.24F, 0.000001)
    assertEquals(in.readByte(), 'g')
    assertEquals(in.readLong(), 42L)
    assertEquals(in.readByte(), 'i')
    assertEquals(in.readInt(), 42)
    assertEquals(in.readByte(), 'b')
    assertEquals(in.readBoolean(), true)
    assertEndOfStream(in)
  }

  @Test
  def shouldWriteDate(): Unit = {
    val date = "2020-12-31"
    val in = whenOutput(out => {
      serDe.writeObject(out, Date.valueOf(date))
    })

    assertEquals(in.readByte(), 'D') // type
    assertEquals(in.readInt(), 10) // size
    assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content
  }

  @Test
  def shouldWriteCustomObjects(): Unit = {
    val customObject = new Object
    val in = whenOutput(out => {
      serDe.writeObject(out, customObject)
    })

    assertEquals(in.readByte(), 'j')
    assertEquals(in.readInt(), 1)
    assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8"))
    assertSame(tracker.get("1").get, customObject)
  }

  @Test
  def shouldWriteArrayOfCustomObjects(): Unit = {
    val payload = Array(new Object, new Object)
    val in = whenOutput(out => {
      serDe.writeObject(out, payload)
    })

    assertEquals(in.readByte(), 'l') // array type
    assertEquals(in.readByte(), 'j') // type of element in array
    assertEquals(in.readInt(), 2) // array length
    assertEquals(in.readInt(), 1) // size of 1st element's identifier
    assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element
    assertEquals(in.readInt(), 1) // size of 2nd element's identifier
    assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element
    assertSame(tracker.get("1").get, payload(0))
    assertSame(tracker.get("2").get, payload(1))
  }

  private def givenInput(func: DataOutputStream => Unit): DataInputStream = {
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    func(out)
    new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
  }

  private def whenOutput = givenInput _

  private def assertEndOfStream(in: DataInputStream): Unit = {
    assertEquals(-1, in.read())
  }
}
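Taken together, the tests above pin down the one-byte type tags of the SerDe wire format. Collected here purely as a reading aid, derived only from the cases exercised in this file:

// Quick reference for the SerDe type tags covered by SerDeTest.
val serDeTypeTags: Map[Char, String] = Map(
  'n' -> "null",
  'i' -> "Int",
  'g' -> "Long",
  'd' -> "Double",
  'b' -> "Boolean",
  'c' -> "String (length-prefixed UTF-8)",
  'D' -> "java.sql.Date (written as a yyyy-MM-dd string)",
  'r' -> "Array[Byte]",
  'l' -> "typed list (element tag, length, values)",
  'e' -> "map (size, keys, then values)",
  'j' -> "tracked JVM object reference (id string)",
  'R' -> "sequence of Rows",
  'O' -> "array of arbitrary objects")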
@ -0,0 +1,80 @@
/*
 * Licensed to the .NET Foundation under one or more agreements.
 * The .NET Foundation licenses this file to you under the MIT license.
 * See the LICENSE file in the project root for more information.
 */

package org.apache.spark.util.dotnet

import org.apache.spark.SparkConf
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK
import org.junit.Assert.{assertEquals, assertThrows}
import org.junit.Test

@Test
class UtilsTest {

  @Test
  def shouldIgnorePatchVersion(): Unit = {
    val sparkVersion = "3.2.1"
    val sparkMajorMinorVersionPrefix = "3.2"
    val supportedSparkVersions = Set[String]("3.2.0")

    Utils.validateSparkVersions(
      true,
      sparkVersion,
      Utils.normalizeSparkVersion(sparkVersion),
      sparkMajorMinorVersionPrefix,
      supportedSparkVersions)
  }

  @Test
  def shouldThrowForUnsupportedVersion(): Unit = {
    val sparkVersion = "3.2.1"
    val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion)
    val sparkMajorMinorVersionPrefix = "3.2"
    val supportedSparkVersions = Set[String]("3.2.0")

    val exception = assertThrows(
      classOf[IllegalArgumentException],
      () => {
        Utils.validateSparkVersions(
          false,
          sparkVersion,
          normalizedSparkVersion,
          sparkMajorMinorVersionPrefix,
          supportedSparkVersions)
      })

    assertEquals(
      s"Unsupported spark version used: '$sparkVersion'. " +
        s"Normalized spark version used: '$normalizedSparkVersion'. " +
        s"Supported versions: '${supportedSparkVersions.toSeq.sorted.mkString(", ")}'.",
      exception.getMessage)
  }

  @Test
  def shouldThrowForUnsupportedMajorMinorVersion(): Unit = {
    val sparkVersion = "2.4.4"
    val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion)
    val sparkMajorMinorVersionPrefix = "3.2"
    val supportedSparkVersions = Set[String]("3.2.0")

    val exception = assertThrows(
      classOf[IllegalArgumentException],
      () => {
        Utils.validateSparkVersions(
          false,
          sparkVersion,
          normalizedSparkVersion,
          sparkMajorMinorVersionPrefix,
          supportedSparkVersions)
      })

    assertEquals(
      s"Unsupported spark version used: '$sparkVersion'. " +
        s"Normalized spark version used: '$normalizedSparkVersion'. " +
        s"Supported spark major.minor version: '$sparkMajorMinorVersionPrefix'.",
      exception.getMessage)
  }
}
@ -14,6 +14,7 @@
        <module>microsoft-spark-2-4</module>
        <module>microsoft-spark-3-0</module>
        <module>microsoft-spark-3-1</module>
        <module>microsoft-spark-3-2</module>
    </modules>

    <pluginRepositories>