Niharika Dutta 2022-01-27 11:13:13 -08:00 committed by GitHub
Parent 6f82f15c22
Commit 7bc016f5ed
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
24 changed files: 2579 additions and 0 deletions

@@ -0,0 +1,77 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>com.microsoft.scala</groupId>
<artifactId>microsoft-spark</artifactId>
<version>${microsoft-spark.version}</version>
</parent>
<artifactId>microsoft-spark-3-2_2.12</artifactId>
<inceptionYear>2019</inceptionYear>
<properties>
<encoding>UTF-8</encoding>
<scala.version>2.12.10</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<spark.version>3.2.0</spark.version>
</properties>
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.specs</groupId>
<artifactId>specs</artifactId>
<version>1.2.5</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<sourceDirectory>src/main/scala</sourceDirectory>
<testSourceDirectory>src/test/scala</testSourceDirectory>
<plugins>
<plugin>
<groupId>org.scala-tools</groupId>
<artifactId>maven-scala-plugin</artifactId>
<version>2.15.2</version>
<executions>
<execution>
<goals>
<goal>compile</goal>
<goal>testCompile</goal>
</goals>
</execution>
</executions>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
<args>
<arg>-target:jvm-1.8</arg>
<arg>-deprecation</arg>
<arg>-feature</arg>
</args>
</configuration>
</plugin>
</plugins>
</build>
</project>

@@ -0,0 +1,72 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import java.io.DataOutputStream
import org.apache.spark.internal.Logging
import scala.collection.mutable.Queue
/**
* CallbackClient is used to communicate with the Dotnet CallbackServer.
* The client manages and maintains a pool of open CallbackConnections.
* Any callback request is delegated to a new CallbackConnection or
* unused CallbackConnection.
* @param address The address of the Dotnet CallbackServer
* @param port The port of the Dotnet CallbackServer
*/
class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging {
private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]()
private[this] var isShutdown: Boolean = false
final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit =
getOrCreateConnection() match {
case Some(connection) =>
try {
connection.send(callbackId, writeBody)
addConnection(connection)
} catch {
case e: Exception =>
logError(s"Error calling callback [callback id = $callbackId].", e)
connection.close()
throw e
}
case None => throw new Exception("Unable to get or create connection.")
}
private def getOrCreateConnection(): Option[CallbackConnection] = synchronized {
if (isShutdown) {
logInfo("Cannot get or create connection while client is shutdown.")
return None
}
if (connectionPool.nonEmpty) {
return Some(connectionPool.dequeue())
}
Some(new CallbackConnection(serDe, address, port))
}
private def addConnection(connection: CallbackConnection): Unit = synchronized {
assert(connection != null)
connectionPool.enqueue(connection)
}
def shutdown(): Unit = synchronized {
if (isShutdown) {
logInfo("Shutdown called, but already shutdown.")
return
}
logInfo("Shutting down.")
connectionPool.foreach(_.close)
connectionPool.clear
isShutdown = true
}
}
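
A minimal usage sketch (not part of this commit; the address, port, and callback id are hypothetical) showing how JVM-side code would invoke a .NET callback through the pooled client above:

// Hypothetical illustration: send callback id 1 with a single string payload.
// Connections are created lazily, so a socket is only opened on the first send.
val tracker = new JVMObjectTracker
val client = new CallbackClient(new SerDe(tracker), "127.0.0.1", 55555)
client.send(1, (dos, serDe) => serDe.writeString(dos, "payload"))
client.shutdown()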

@@ -0,0 +1,112 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream}
import java.net.Socket
import org.apache.spark.internal.Logging
/**
* CallbackConnection is used to process the callback communication
* between the JVM and Dotnet. It uses a TCP socket to communicate with
* the Dotnet CallbackServer and the socket is expected to be reused.
* @param address The address of the Dotnet CallbackServer
* @param port The port of the Dotnet CallbackServer
*/
class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging {
private[this] val socket: Socket = new Socket(address, port)
private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream)
private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream)
def send(
callbackId: Int,
writeBody: (DataOutputStream, SerDe) => Unit): Unit = {
logInfo(s"Calling callback [callback id = $callbackId] ...")
try {
serDe.writeInt(outputStream, CallbackFlags.CALLBACK)
serDe.writeInt(outputStream, callbackId)
val byteArrayOutputStream = new ByteArrayOutputStream()
writeBody(new DataOutputStream(byteArrayOutputStream), serDe)
serDe.writeInt(outputStream, byteArrayOutputStream.size)
byteArrayOutputStream.writeTo(outputStream);
} catch {
case e: Exception => {
throw new Exception("Error writing to stream.", e)
}
}
logInfo(s"Signaling END_OF_STREAM.")
try {
serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM)
outputStream.flush()
val endOfStreamResponse = readFlag(inputStream)
endOfStreamResponse match {
case CallbackFlags.END_OF_STREAM =>
logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.")
case _ => {
throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " +
s"Received: $endOfStreamResponse")
}
}
} catch {
case e: Exception => {
throw new Exception("Error while verifying end of stream.", e)
}
}
}
def close(): Unit = {
try {
serDe.writeInt(outputStream, CallbackFlags.CLOSE)
outputStream.flush()
} catch {
case e: Exception => logInfo("Unable to send close to .NET callback server.", e)
}
close(socket)
close(outputStream)
close(inputStream)
}
private def close(s: Socket): Unit = {
try {
assert(s != null)
s.close()
} catch {
case e: Exception => logInfo("Unable to close socket.", e)
}
}
private def close(c: Closeable): Unit = {
try {
assert(c != null)
c.close()
} catch {
case e: Exception => logInfo("Unable to close closeable.", e)
}
}
private def readFlag(inputStream: DataInputStream): Int = {
val callbackFlag = serDe.readInt(inputStream)
if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) {
val exceptionMessage = serDe.readString(inputStream)
throw new DotnetException(exceptionMessage)
}
callbackFlag
}
private object CallbackFlags {
val CLOSE: Int = -1
val CALLBACK: Int = -2
val DOTNET_EXCEPTION_THROWN: Int = -3
val END_OF_STREAM: Int = -4
}
}
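
For reference, this is the wire format produced by send above, as read directly from the code (ints are written big-endian by DataOutputStream):

// int32  CALLBACK flag (-2)
// int32  callbackId
// int32  length of the body in bytes
// bytes  body, produced by the writeBody function
// int32  END_OF_STREAM flag (-4)
// The .NET side replies with END_OF_STREAM (-4) on success, or with
// DOTNET_EXCEPTION_THROWN (-3) followed by an exception message string,
// which readFlag turns into a DotnetException. CLOSE (-1) is sent on close().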

@@ -0,0 +1,113 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import java.net.InetSocketAddress
import java.util.concurrent.TimeUnit
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS
import org.apache.spark.{SparkConf, SparkEnv}
/**
* Netty server that invokes JVM calls based upon receiving messages from .NET.
* The implementation mirrors the RBackend.
*
*/
class DotnetBackend extends Logging {
self => // for accessing the this reference in the inner class (ChannelInitializer)
private[this] var channelFuture: ChannelFuture = _
private[this] var bootstrap: ServerBootstrap = _
private[this] var bossGroup: EventLoopGroup = _
private[this] val objectTracker = new JVMObjectTracker
@volatile
private[dotnet] var callbackClient: Option[CallbackClient] = None
def init(portNumber: Int): Int = {
val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS)
logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.")
bossGroup = new NioEventLoopGroup(numBackendThreads)
val workerGroup = bossGroup
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
.channel(classOf[NioServerSocketChannel])
bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
def initChannel(ch: SocketChannel): Unit = {
ch.pipeline()
.addLast("encoder", new ByteArrayEncoder())
.addLast(
"frameDecoder",
// maxFrameLength = 2G
// lengthFieldOffset = 0
// lengthFieldLength = 4
// lengthAdjustment = 0
// initialBytesToStrip = 4, i.e. strip out the length field itself
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
.addLast("handler", new DotnetBackendHandler(self, objectTracker))
}
})
channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber))
channelFuture.syncUninterruptibly()
channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort
}
private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized {
callbackClient = callbackClient match {
case Some(_) => throw new Exception("Callback client already set.")
case None =>
logInfo(s"Connecting to a callback server at $address:$port")
Some(new CallbackClient(new SerDe(objectTracker), address, port))
}
}
private[dotnet] def shutdownCallbackClient(): Unit = synchronized {
callbackClient match {
case Some(client) => client.shutdown()
case None => logInfo("Callback server has already been shutdown.")
}
callbackClient = None
}
def run(): Unit = {
channelFuture.channel.closeFuture().syncUninterruptibly()
}
def close(): Unit = {
if (channelFuture != null) {
// close is a local operation and should finish within milliseconds; timeout just to be safe
channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS)
channelFuture = null
}
if (bootstrap != null && bootstrap.config().group() != null) {
bootstrap.config().group().shutdownGracefully()
}
if (bootstrap != null && bootstrap.config().childGroup() != null) {
bootstrap.config().childGroup().shutdownGracefully()
}
bootstrap = null
objectTracker.clear()
// Send close to .NET callback server.
shutdownCallbackClient()
// Shutdown the thread pool whose executors could still be running.
ThreadPool.shutdown()
}
}
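
A minimal sketch (not part of this commit) of how the backend is started; DotnetRunner, later in this commit, does essentially the same thing:

val backend = new DotnetBackend()
val boundPort = backend.init(0) // port 0: bind an ephemeral port; the actual port is returned
// Hand boundPort to the .NET process (DotnetRunner exports it as DOTNETBACKEND_PORT),
// then serve on a dedicated thread until close() is called from elsewhere.
new Thread("DotnetBackend") {
  override def run(): Unit = backend.run() // blocks on the Netty channel's close future
}.start()
// ... later, once the .NET side is done:
backend.close()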

@@ -0,0 +1,335 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import scala.collection.mutable.HashMap
import scala.language.existentials
/**
* Handler for DotnetBackend.
* This implementation is similar to RBackendHandler.
*/
class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker)
extends SimpleChannelInboundHandler[Array[Byte]]
with Logging {
private[this] val serDe = new SerDe(objectsTracker)
override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
val reply = handleBackendRequest(msg)
ctx.write(reply)
}
override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
ctx.flush()
}
def handleBackendRequest(msg: Array[Byte]): Array[Byte] = {
val bis = new ByteArrayInputStream(msg)
val dis = new DataInputStream(bis)
val bos = new ByteArrayOutputStream()
val dos = new DataOutputStream(bos)
// First bit is isStatic
val isStatic = serDe.readBoolean(dis)
val processId = serDe.readInt(dis)
val threadId = serDe.readInt(dis)
val objId = serDe.readString(dis)
val methodName = serDe.readString(dis)
val numArgs = serDe.readInt(dis)
if (objId == "DotnetHandler") {
methodName match {
case "stopBackend" =>
serDe.writeInt(dos, 0)
serDe.writeType(dos, "void")
server.close()
case "rm" =>
try {
val t = serDe.readObjectType(dis)
assert(t == 'c')
val objToRemove = serDe.readString(dis)
objectsTracker.remove(objToRemove)
serDe.writeInt(dos, 0)
serDe.writeObject(dos, null)
} catch {
case e: Exception =>
logError(s"Removing $objId failed", e)
serDe.writeInt(dos, -1)
}
case "rmThread" =>
try {
assert(serDe.readObjectType(dis) == 'i')
val processId = serDe.readInt(dis)
assert(serDe.readObjectType(dis) == 'i')
val threadToDelete = serDe.readInt(dis)
val result = ThreadPool.tryDeleteThread(processId, threadToDelete)
serDe.writeInt(dos, 0)
serDe.writeObject(dos, result.asInstanceOf[AnyRef])
} catch {
case e: Exception =>
logError(s"Removing thread $threadId failed", e)
serDe.writeInt(dos, -1)
}
case "connectCallback" =>
assert(serDe.readObjectType(dis) == 'c')
val address = serDe.readString(dis)
assert(serDe.readObjectType(dis) == 'i')
val port = serDe.readInt(dis)
server.setCallbackClient(address, port)
serDe.writeInt(dos, 0)
// Sends reference of CallbackClient to dotnet side,
// so that dotnet process can send the client back to Java side
// when calling any API containing callback functions.
serDe.writeObject(dos, server.callbackClient)
case "closeCallback" =>
logInfo("Requesting to close callback client")
server.shutdownCallbackClient()
serDe.writeInt(dos, 0)
serDe.writeType(dos, "void")
case _ => dos.writeInt(-1)
}
} else {
ThreadPool
.run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos))
}
bos.toByteArray
}
override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
// Skip logging the exception message if the connection was disconnected from
// the .NET side so that .NET side doesn't have to explicitly close the connection via
// "stopBackend." Note that an exception is still thrown if the exit status is non-zero,
// so skipping this kind of exception message does not affect the debugging.
if (!cause.getMessage.contains(
"An existing connection was forcibly closed by the remote host")) {
logError("Exception caught: ", cause)
}
// Close the connection when an exception is raised.
ctx.close()
}
def handleMethodCall(
isStatic: Boolean,
objId: String,
methodName: String,
numArgs: Int,
dis: DataInputStream,
dos: DataOutputStream): Unit = {
var obj: Object = null
var args: Array[java.lang.Object] = null
var methods: Array[java.lang.reflect.Method] = null
try {
val cls = if (isStatic) {
Utils.classForName(objId)
} else {
objectsTracker.get(objId) match {
case None => throw new IllegalArgumentException("Object not found " + objId)
case Some(o) =>
obj = o
o.getClass
}
}
args = readArgs(numArgs, dis)
methods = cls.getMethods
val selectedMethods = methods.filter(m => m.getName == methodName)
if (selectedMethods.length > 0) {
val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args)
if (index.isEmpty) {
logWarning(
s"cannot find matching method ${cls}.$methodName. "
+ s"Candidates are:")
selectedMethods.foreach { method =>
logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})")
}
throw new Exception(s"No matched method found for $cls.$methodName")
}
val ret = selectedMethods(index.get).invoke(obj, args: _*)
// Write status bit
serDe.writeInt(dos, 0)
serDe.writeObject(dos, ret.asInstanceOf[AnyRef])
} else if (methodName == "<init>") {
// methodName should be "<init>" for constructor
val ctor = cls.getConstructors.filter { x =>
matchMethod(numArgs, args, x.getParameterTypes)
}.head
val obj = ctor.newInstance(args: _*)
serDe.writeInt(dos, 0)
serDe.writeObject(dos, obj.asInstanceOf[AnyRef])
} else {
throw new IllegalArgumentException(
"invalid method " + methodName + " for object " + objId)
}
} catch {
case e: Throwable =>
val jvmObj = objectsTracker.get(objId)
val jvmObjName = jvmObj match {
case Some(jObj) => jObj.getClass.getName
case None => "NullObject"
}
val argsStr = args
.map(arg => {
if (arg != null) {
s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]"
} else {
"[Value: NULL]"
}
})
.mkString(", ")
logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)")
if (methods != null) {
logDebug(s"All methods for $jvmObjName:")
methods.foreach(m => logDebug(m.toString))
}
serDe.writeInt(dos, -1)
serDe.writeString(dos, Utils.exceptionString(e.getCause))
}
}
// Read a number of arguments from the data input stream
def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
(0 until numArgs).map { arg =>
serDe.readObject(dis)
}.toArray
}
// Checks if the arguments passed in args matches the parameter types.
// NOTE: Currently we do exact match. We may add type conversions later.
def matchMethod(
numArgs: Int,
args: Array[java.lang.Object],
parameterTypes: Array[Class[_]]): Boolean = {
if (parameterTypes.length != numArgs) {
return false
}
for (i <- 0 until numArgs) {
val parameterType = parameterTypes(i)
var parameterWrapperType = parameterType
// Convert native parameters to Object types as args is Array[Object] here
if (parameterType.isPrimitive) {
parameterWrapperType = parameterType match {
case java.lang.Integer.TYPE => classOf[java.lang.Integer]
case java.lang.Long.TYPE => classOf[java.lang.Long]
case java.lang.Double.TYPE => classOf[java.lang.Double]
case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
case _ => parameterType
}
}
if (!parameterWrapperType.isInstance(args(i))) {
// non primitive types
if (!parameterType.isPrimitive && args(i) != null) {
return false
}
// primitive types
if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) {
return false
}
}
}
true
}
// Find a matching method signature in an array of signatures of constructors
// or methods of the same name according to the passed arguments. Arguments
// may be converted in order to match a signature.
//
// Note that in Java reflection, constructors and normal methods are of different
// classes, and share no parent class that provides methods for reflection uses.
// There is no unified way to handle them in this function. So an array of signatures
// is passed in instead of an array of candidate constructors or methods.
//
// Returns an Option[Int] which is the index of the matched signature in the array.
def findMatchedSignature(
parameterTypesOfMethods: Array[Array[Class[_]]],
args: Array[Object]): Option[Int] = {
val numArgs = args.length
for (index <- parameterTypesOfMethods.indices) {
val parameterTypes = parameterTypesOfMethods(index)
if (parameterTypes.length == numArgs) {
var argMatched = true
var i = 0
while (i < numArgs && argMatched) {
val parameterType = parameterTypes(i)
if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) {
// The case that the parameter type is a Scala Seq and the argument
// is a Java array is considered matching. The array will be converted
// to a Seq later if this method is matched.
} else {
var parameterWrapperType = parameterType
// Convert native parameters to Object types as args is Array[Object] here
if (parameterType.isPrimitive) {
parameterWrapperType = parameterType match {
case java.lang.Integer.TYPE => classOf[java.lang.Integer]
case java.lang.Long.TYPE => classOf[java.lang.Long]
case java.lang.Double.TYPE => classOf[java.lang.Double]
case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
case _ => parameterType
}
}
if ((parameterType.isPrimitive || args(i) != null) &&
!parameterWrapperType.isInstance(args(i))) {
argMatched = false
}
}
i = i + 1
}
if (argMatched) {
// For now, we return the first matching method.
// TODO: find best method in matching methods.
// Convert args if needed
val parameterTypes = parameterTypesOfMethods(index)
for (i <- 0 until numArgs) {
if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) {
// Convert a Java array to scala Seq
args(i) = args(i).asInstanceOf[Array[_]].toSeq
}
}
return Some(index)
}
}
}
None
}
def logError(id: String, e: Exception): Unit = {}
}
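
For reference, the request layout consumed by handleBackendRequest above (values are encoded in the SerDe format defined later in this commit):

// bool    isStatic            - true for a static method or constructor call on a class
// int32   processId, threadId - identify the calling .NET thread; used to route the call
//                               to a dedicated JVM thread via ThreadPool
// string  objId               - "DotnetHandler" for control commands, otherwise a
//                               JVMObjectTracker id or a fully-qualified class name
// string  methodName          - method to invoke, or "<init>" for constructors
// int32   numArgs             - followed by numArgs typed values (see SerDe)
// The reply starts with an int32 status (0 on success, -1 on failure), followed by the
// serialized return value or, on failure, the exception string.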

@@ -0,0 +1,13 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
class DotnetException(message: String, cause: Throwable)
extends Exception(message, cause) {
def this(message: String) = this(message, null)
}

@@ -0,0 +1,30 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python._
import org.apache.spark.rdd.RDD
object DotnetRDD {
def createPythonRDD(
parent: RDD[_],
func: PythonFunction,
preservePartitioning: Boolean): PythonRDD = {
new PythonRDD(parent, func, preservePartitioning)
}
def createJavaRDDFromArray(
sc: SparkContext,
arr: Array[Array[Byte]],
numSlices: Int): JavaRDD[Array[Byte]] = {
JavaRDD.fromRDD(sc.parallelize(arr, numSlices))
}
def toJavaRDD(rdd: RDD[_]): JavaRDD[_] = JavaRDD.fromRDD(rdd)
}

@@ -0,0 +1,55 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import scala.collection.mutable.HashMap
/**
* Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects.
*/
private[dotnet] class JVMObjectTracker {
// Multiple threads may access objMap and increase objCounter. Because the get method returns
// an Option, it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap.
private[this] val objMap = new HashMap[String, Object]
private[this] var objCounter: Int = 1
def getObject(id: String): Object = {
synchronized {
objMap(id)
}
}
def get(id: String): Option[Object] = {
synchronized {
objMap.get(id)
}
}
def put(obj: Object): String = {
synchronized {
val objId = objCounter.toString
objCounter = objCounter + 1
objMap.put(objId, obj)
objId
}
}
def remove(id: String): Option[Object] = {
synchronized {
objMap.remove(id)
}
}
def clear(): Unit = {
synchronized {
objMap.clear()
objCounter = 1
}
}
}
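
An illustration (not part of this commit) of the tracker's lifecycle; the ids are the decimal counter values handed to the .NET side:

val tracker = new JVMObjectTracker
val id = tracker.put(new java.util.ArrayList[String]())  // first id is "1"
tracker.get(id)      // Some(obj)
tracker.remove(id)   // Some(obj); a later get(id) returns None
tracker.clear()      // drops everything and resets the counter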

@@ -0,0 +1,33 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.sql.api.dotnet
import org.apache.spark.SparkConf
/*
* Utils for JvmBridge.
*/
object JvmBridgeUtils {
def getKeyValuePairAsString(kvp: (String, String)): String = {
return kvp._1 + "=" + kvp._2
}
def getKeyValuePairArrayAsString(kvpArray: Array[(String, String)]): String = {
val sb = new StringBuilder
for (kvp <- kvpArray) {
sb.append(getKeyValuePairAsString(kvp))
sb.append(";")
}
sb.toString
}
def getSparkConfAsString(sparkConf: SparkConf): String = {
getKeyValuePairArrayAsString(sparkConf.getAll)
}
}

@@ -0,0 +1,387 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import java.io.{DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Time, Timestamp}
import org.apache.spark.sql.Row
import scala.collection.JavaConverters._
/**
* Class responsible for serialization and deserialization between CLR & JVM.
* This implementation of methods is mostly identical to the SerDe implementation in R.
*/
class SerDe(val tracker: JVMObjectTracker) {
def readObjectType(dis: DataInputStream): Char = {
dis.readByte().toChar
}
def readObject(dis: DataInputStream): Object = {
val dataType = readObjectType(dis)
readTypedObject(dis, dataType)
}
private def readTypedObject(dis: DataInputStream, dataType: Char): Object = {
dataType match {
case 'n' => null
case 'i' => new java.lang.Integer(readInt(dis))
case 'g' => new java.lang.Long(readLong(dis))
case 'd' => new java.lang.Double(readDouble(dis))
case 'b' => new java.lang.Boolean(readBoolean(dis))
case 'c' => readString(dis)
case 'e' => readMap(dis)
case 'r' => readBytes(dis)
case 'l' => readList(dis)
case 'D' => readDate(dis)
case 't' => readTime(dis)
case 'j' => tracker.getObject(readString(dis))
case 'R' => readRowArr(dis)
case 'O' => readObjectArr(dis)
case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
}
}
private def readBytes(in: DataInputStream): Array[Byte] = {
val len = readInt(in)
val out = new Array[Byte](len)
in.readFully(out)
out
}
def readInt(in: DataInputStream): Int = {
in.readInt()
}
private def readLong(in: DataInputStream): Long = {
in.readLong()
}
private def readDouble(in: DataInputStream): Double = {
in.readDouble()
}
private def readStringBytes(in: DataInputStream, len: Int): String = {
val bytes = new Array[Byte](len)
in.readFully(bytes)
val str = new String(bytes, "UTF-8")
str
}
def readString(in: DataInputStream): String = {
val len = in.readInt()
readStringBytes(in, len)
}
def readBoolean(in: DataInputStream): Boolean = {
in.readBoolean()
}
private def readDate(in: DataInputStream): Date = {
Date.valueOf(readString(in))
}
private def readTime(in: DataInputStream): Timestamp = {
val seconds = in.readDouble()
val sec = Math.floor(seconds).toLong
val t = new Timestamp(sec * 1000L)
t.setNanos(((seconds - sec) * 1e9).toInt)
t
}
private def readRow(in: DataInputStream): Row = {
val len = readInt(in)
Row.fromSeq((0 until len).map(_ => readObject(in)))
}
private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = {
val len = readInt(in)
(0 until len).map(_ => readBytes(in)).toArray
}
private def readIntArr(in: DataInputStream): Array[Int] = {
val len = readInt(in)
(0 until len).map(_ => readInt(in)).toArray
}
private def readLongArr(in: DataInputStream): Array[Long] = {
val len = readInt(in)
(0 until len).map(_ => readLong(in)).toArray
}
private def readDoubleArr(in: DataInputStream): Array[Double] = {
val len = readInt(in)
(0 until len).map(_ => readDouble(in)).toArray
}
private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = {
val len = readInt(in)
(0 until len).map(_ => readDoubleArr(in)).toArray
}
private def readBooleanArr(in: DataInputStream): Array[Boolean] = {
val len = readInt(in)
(0 until len).map(_ => readBoolean(in)).toArray
}
private def readStringArr(in: DataInputStream): Array[String] = {
val len = readInt(in)
(0 until len).map(_ => readString(in)).toArray
}
private def readRowArr(in: DataInputStream): java.util.List[Row] = {
val len = readInt(in)
(0 until len).map(_ => readRow(in)).toList.asJava
}
private def readObjectArr(in: DataInputStream): Seq[Any] = {
val len = readInt(in)
(0 until len).map(_ => readObject(in))
}
private def readList(dis: DataInputStream): Array[_] = {
val arrType = readObjectType(dis)
arrType match {
case 'i' => readIntArr(dis)
case 'g' => readLongArr(dis)
case 'c' => readStringArr(dis)
case 'd' => readDoubleArr(dis)
case 'A' => readDoubleArrArr(dis)
case 'b' => readBooleanArr(dis)
case 'j' => readStringArr(dis).map(x => tracker.getObject(x))
case 'r' => readBytesArr(dis)
case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
}
}
private def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
val len = readInt(in)
if (len > 0) {
val keysType = readObjectType(in)
val keysLen = readInt(in)
val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
val valuesLen = readInt(in)
val values = (0 until valuesLen).map(_ => {
val valueType = readObjectType(in)
readTypedObject(in, valueType)
})
keys.zip(values).toMap.asJava
} else {
new java.util.HashMap[Object, Object]()
}
}
// Using the same mapping as SparkR implementation for now
// Methods to write out data from Java to .NET.
//
// Type mapping from Java to .NET:
//
// void -> NULL
// Int -> integer
// String -> character
// Boolean -> logical
// Float -> double
// Double -> double
// Long -> long
// Array[Byte] -> raw
// Date -> Date
// Time -> POSIXct
//
// Array[T] -> list()
// Object -> jobj
def writeType(dos: DataOutputStream, typeStr: String): Unit = {
typeStr match {
case "void" => dos.writeByte('n')
case "character" => dos.writeByte('c')
case "double" => dos.writeByte('d')
case "doublearray" => dos.writeByte('A')
case "long" => dos.writeByte('g')
case "integer" => dos.writeByte('i')
case "logical" => dos.writeByte('b')
case "date" => dos.writeByte('D')
case "time" => dos.writeByte('t')
case "raw" => dos.writeByte('r')
case "list" => dos.writeByte('l')
case "jobj" => dos.writeByte('j')
case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
}
}
def writeObject(dos: DataOutputStream, value: Object): Unit = {
if (value == null || value == Unit) {
writeType(dos, "void")
} else {
value.getClass.getName match {
case "java.lang.String" =>
writeType(dos, "character")
writeString(dos, value.asInstanceOf[String])
case "float" | "java.lang.Float" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Float].toDouble)
case "double" | "java.lang.Double" =>
writeType(dos, "double")
writeDouble(dos, value.asInstanceOf[Double])
case "long" | "java.lang.Long" =>
writeType(dos, "long")
writeLong(dos, value.asInstanceOf[Long])
case "int" | "java.lang.Integer" =>
writeType(dos, "integer")
writeInt(dos, value.asInstanceOf[Int])
case "boolean" | "java.lang.Boolean" =>
writeType(dos, "logical")
writeBoolean(dos, value.asInstanceOf[Boolean])
case "java.sql.Date" =>
writeType(dos, "date")
writeDate(dos, value.asInstanceOf[Date])
case "java.sql.Time" =>
writeType(dos, "time")
writeTime(dos, value.asInstanceOf[Time])
case "java.sql.Timestamp" =>
writeType(dos, "time")
writeTime(dos, value.asInstanceOf[Timestamp])
case "[B" =>
writeType(dos, "raw")
writeBytes(dos, value.asInstanceOf[Array[Byte]])
// TODO: Types not handled right now include
// byte, char, short, float
// Handle arrays
case "[Ljava.lang.String;" =>
writeType(dos, "list")
writeStringArr(dos, value.asInstanceOf[Array[String]])
case "[I" =>
writeType(dos, "list")
writeIntArr(dos, value.asInstanceOf[Array[Int]])
case "[J" =>
writeType(dos, "list")
writeLongArr(dos, value.asInstanceOf[Array[Long]])
case "[D" =>
writeType(dos, "list")
writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
case "[[D" =>
writeType(dos, "list")
writeDoubleArrArr(dos, value.asInstanceOf[Array[Array[Double]]])
case "[Z" =>
writeType(dos, "list")
writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
case "[[B" =>
writeType(dos, "list")
writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
case otherName =>
// Handle array of objects
if (otherName.startsWith("[L")) {
val objArr = value.asInstanceOf[Array[Object]]
writeType(dos, "list")
writeType(dos, "jobj")
dos.writeInt(objArr.length)
objArr.foreach(o => writeJObj(dos, o))
} else {
writeType(dos, "jobj")
writeJObj(dos, value)
}
}
}
}
def writeInt(out: DataOutputStream, value: Int): Unit = {
out.writeInt(value)
}
def writeLong(out: DataOutputStream, value: Long): Unit = {
out.writeLong(value)
}
private def writeDouble(out: DataOutputStream, value: Double): Unit = {
out.writeDouble(value)
}
private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = {
out.writeBoolean(value)
}
private def writeDate(out: DataOutputStream, value: Date): Unit = {
writeString(out, value.toString)
}
private def writeTime(out: DataOutputStream, value: Time): Unit = {
out.writeDouble(value.getTime.toDouble / 1000.0)
}
private def writeTime(out: DataOutputStream, value: Timestamp): Unit = {
out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9)
}
def writeString(out: DataOutputStream, value: String): Unit = {
val utf8 = value.getBytes(StandardCharsets.UTF_8)
val len = utf8.length
out.writeInt(len)
out.write(utf8, 0, len)
}
private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = {
out.writeInt(value.length)
out.write(value)
}
def writeJObj(out: DataOutputStream, value: Object): Unit = {
val objId = tracker.put(value)
writeString(out, objId)
}
private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = {
writeType(out, "integer")
out.writeInt(value.length)
value.foreach(v => out.writeInt(v))
}
private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = {
writeType(out, "long")
out.writeInt(value.length)
value.foreach(v => out.writeLong(v))
}
private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = {
writeType(out, "double")
out.writeInt(value.length)
value.foreach(v => out.writeDouble(v))
}
private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = {
writeType(out, "doublearray")
out.writeInt(value.length)
value.foreach(v => writeDoubleArr(out, v))
}
private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = {
writeType(out, "logical")
out.writeInt(value.length)
value.foreach(v => writeBoolean(out, v))
}
private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = {
writeType(out, "character")
out.writeInt(value.length)
value.foreach(v => writeString(out, v))
}
private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
writeType(out, "raw")
out.writeInt(value.length)
value.foreach(v => writeBytes(out, v))
}
}
private object SerializationFormats {
val BYTE = "byte"
val STRING = "string"
val ROW = "row"
}
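
A hypothetical round-trip over in-memory streams (not part of this commit), illustrating that the read and write paths above are symmetric:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

val serDe = new SerDe(new JVMObjectTracker)
val buffer = new ByteArrayOutputStream()
serDe.writeObject(new DataOutputStream(buffer), Integer.valueOf(42)) // writes type 'i', then 42
val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
serDe.readObject(in) // returns java.lang.Integer(42)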

@@ -0,0 +1,72 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import java.util.concurrent.{ExecutorService, Executors}
import scala.collection.mutable
/**
* Pool of thread executors. There should be a 1-1 correspondence between C# threads
* and Java threads.
*/
object ThreadPool {
/**
* Map from (processId, threadId) to corresponding executor.
*/
private val executors: mutable.HashMap[(Int, Int), ExecutorService] =
new mutable.HashMap[(Int, Int), ExecutorService]()
/**
* Run some code on a particular thread.
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @param task Function to run on the thread.
*/
def run(processId: Int, threadId: Int, task: () => Unit): Unit = {
val executor = getOrCreateExecutor(processId, threadId)
val future = executor.submit(new Runnable {
override def run(): Unit = task()
})
future.get()
}
/**
* Try to delete a particular thread.
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @return True if successful, false if thread does not exist.
*/
def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized {
executors.remove((processId, threadId)) match {
case Some(executorService) =>
executorService.shutdown()
true
case None => false
}
}
/**
* Shutdown any running ExecutorServices.
*/
def shutdown(): Unit = synchronized {
executors.foreach(_._2.shutdown())
executors.clear()
}
/**
* Get the executor if it exists, otherwise create a new one.
* @param processId Integer id of the process.
* @param threadId Integer id of the thread.
* @return The new or existing executor with the given id.
*/
private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized {
executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor)
}
}
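
An illustration with hypothetical ids of the thread affinity the pool provides: every call with the same (processId, threadId) pair runs sequentially on the same single-threaded executor.

ThreadPool.run(1, 7, () => println("runs on the JVM thread dedicated to (1, 7)"))
ThreadPool.run(1, 7, () => println("runs on that same thread, after the first task"))
ThreadPool.tryDeleteThread(1, 7)  // true; shuts the executor down
ThreadPool.tryDeleteThread(1, 7)  // false; it no longer exists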

@@ -0,0 +1,285 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.deploy.dotnet
import java.io.File
import java.net.URI
import java.nio.file.attribute.PosixFilePermissions
import java.nio.file.{FileSystems, Files, Paths}
import java.util.Locale
import java.util.concurrent.{Semaphore, TimeUnit}
import org.apache.commons.io.FilenameUtils
import org.apache.hadoop.fs.Path
import org.apache.spark
import org.apache.spark.api.dotnet.DotnetBackend
import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK
import org.apache.spark.util.dotnet.{Utils => DotnetUtils}
import org.apache.spark.util.{RedirectThread, Utils}
import org.apache.spark.{SecurityManager, SparkConf, SparkUserAppException}
import scala.collection.JavaConverters._
import scala.io.StdIn
import scala.util.Try
/**
* DotnetRunner class used to launch Spark .NET applications using spark-submit.
* It executes the .NET application as a subprocess and then has it connect back to
* the JVM to access system properties etc.
*/
object DotnetRunner extends Logging {
private val DEBUG_PORT = 5567
private val supportedSparkMajorMinorVersionPrefix = "3.2"
private val supportedSparkVersions = Set[String]("3.2.0", "3.2.1")
val SPARK_VERSION = DotnetUtils.normalizeSparkVersion(spark.SPARK_VERSION)
def main(args: Array[String]): Unit = {
if (args.length == 0) {
throw new IllegalArgumentException("At least one argument is expected.")
}
DotnetUtils.validateSparkVersions(
sys.props
.getOrElse(
DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.key,
DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.defaultValue.get.toString)
.toBoolean,
spark.SPARK_VERSION,
SPARK_VERSION,
supportedSparkMajorMinorVersionPrefix,
supportedSparkVersions)
val settings = initializeSettings(args)
// Determines if this needs to be run in debug mode.
// In debug mode this runner will not launch a .NET process.
val runInDebugMode = settings._1
@volatile var dotnetBackendPortNumber = settings._2
var dotnetExecutable = ""
var otherArgs: Array[String] = null
if (!runInDebugMode) {
if (args(0).toLowerCase(Locale.ROOT).endsWith(".zip")) {
var zipFileName = args(0)
val zipFileUri = Try(new URI(zipFileName)).getOrElse(new File(zipFileName).toURI)
val workingDir = new File("").getAbsoluteFile
val driverDir = new File(workingDir, FilenameUtils.getBaseName(zipFileUri.getPath()))
// Standalone cluster mode where .NET application is remotely located.
if (zipFileUri.getScheme() != "file") {
zipFileName = downloadDriverFile(zipFileName, workingDir.getAbsolutePath).getName
}
logInfo(s"Unzipping .NET driver $zipFileName to $driverDir")
DotnetUtils.unzip(new File(zipFileName), driverDir)
// Reuse windows-specific formatting in PythonRunner.
dotnetExecutable = PythonRunner.formatPath(resolveDotnetExecutable(driverDir, args(1)))
otherArgs = args.slice(2, args.length)
} else {
// Reuse windows-specific formatting in PythonRunner.
dotnetExecutable = PythonRunner.formatPath(args(0))
otherArgs = args.slice(1, args.length)
}
} else {
otherArgs = args.slice(1, args.length)
}
val processParameters = new java.util.ArrayList[String]
processParameters.add(dotnetExecutable)
otherArgs.foreach(arg => processParameters.add(arg))
logInfo(s"Starting DotnetBackend with $dotnetExecutable.")
// Time to wait for DotnetBackend to initialize in seconds.
val backendTimeout = sys.env.getOrElse("DOTNETBACKEND_TIMEOUT", "120").toInt
// Launch a DotnetBackend server for the .NET process to connect to; this will let it see our
// Java system properties etc.
val dotnetBackend = new DotnetBackend()
val initialized = new Semaphore(0)
val dotnetBackendThread = new Thread("DotnetBackend") {
override def run() {
// need to get back dotnetBackendPortNumber because if the value passed to init is 0
// the port number is dynamically assigned in the backend
dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber)
logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber")
initialized.release()
dotnetBackend.run()
}
}
dotnetBackendThread.start()
if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) {
if (!runInDebugMode) {
var returnCode = -1
var process: Process = null
try {
val builder = new ProcessBuilder(processParameters)
val env = builder.environment()
env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString)
for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) {
env.put(key, value)
logInfo(s"Adding key=$key and value=$value to environment")
}
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
process = builder.start()
// Redirect stdin of JVM process to stdin of .NET process.
new RedirectThread(System.in, process.getOutputStream, "redirect JVM input").start()
// Redirect stdout and stderr of .NET process.
new RedirectThread(process.getInputStream, System.out, "redirect .NET stdout").start()
new RedirectThread(process.getErrorStream, System.out, "redirect .NET stderr").start()
process.waitFor()
} catch {
case t: Throwable =>
logThrowable(t)
} finally {
returnCode = closeDotnetProcess(process)
closeBackend(dotnetBackend)
}
if (returnCode != 0) {
throw new SparkUserAppException(returnCode)
} else {
logInfo(s".NET application exited successfully")
}
// TODO: The following is causing the following error:
// INFO ApplicationMaster: Final app status: FAILED, exitCode: 16,
// (reason: Shutdown hook called before final status was reported.)
// DotnetUtils.exit(returnCode)
} else {
// scalastyle:off println
println("***********************************************************************")
println("* .NET Backend running debug mode. Press enter to exit *")
println("***********************************************************************")
// scalastyle:on println
StdIn.readLine()
closeBackend(dotnetBackend)
DotnetUtils.exit(0)
}
} else {
logError(s"DotnetBackend did not initialize in $backendTimeout seconds")
DotnetUtils.exit(-1)
}
}
// When the executable is downloaded as part of a zip file, check if the file exists
// after the zip file is unzipped under the given dir. Once it is found, change its
// permission to executable (only for Unix systems, since the zip file may have been
// created under Windows). Finally, the absolute path of the executable is returned.
private def resolveDotnetExecutable(dir: File, dotnetExecutable: String): String = {
val path = Paths.get(dir.getAbsolutePath, dotnetExecutable)
val resolvedExecutable = if (Files.isRegularFile(path)) {
path.toAbsolutePath.toString
} else {
Files
.walk(FileSystems.getDefault.getPath(dir.getAbsolutePath))
.iterator()
.asScala
.find(path => Files.isRegularFile(path) && path.getFileName.toString == dotnetExecutable) match {
case Some(path) => path.toAbsolutePath.toString
case None =>
throw new IllegalArgumentException(
s"Failed to find $dotnetExecutable under ${dir.getAbsolutePath}")
}
}
if (DotnetUtils.supportPosix) {
Files.setPosixFilePermissions(
Paths.get(resolvedExecutable),
PosixFilePermissions.fromString("rwxr-xr-x"))
}
resolvedExecutable
}
/**
* Download HDFS file into the supplied directory and return its local path.
* Will throw an exception if there are errors during downloading.
*/
private def downloadDriverFile(hdfsFilePath: String, driverDir: String): File = {
val sparkConf = new SparkConf()
val filePath = new Path(hdfsFilePath)
val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf)
val jarFileName = filePath.getName
val localFile = new File(driverDir, jarFileName)
if (!localFile.exists()) { // May already exist if running multiple workers on one node
logInfo(s"Copying user file $filePath to $driverDir")
Utils.fetchFile(
hdfsFilePath,
new File(driverDir),
sparkConf,
hadoopConf,
System.currentTimeMillis(),
useCache = false)
}
if (!localFile.exists()) {
throw new Exception(s"Did not see expected $jarFileName in $driverDir")
}
localFile
}
private def closeBackend(dotnetBackend: DotnetBackend): Unit = {
logInfo("Closing DotnetBackend")
dotnetBackend.close()
}
private def closeDotnetProcess(dotnetProcess: Process): Int = {
if (dotnetProcess == null) {
return -1
} else if (!dotnetProcess.isAlive) {
return dotnetProcess.exitValue()
}
// Try to (gracefully on Linux) kill the process and resort to force if interrupted
var returnCode = -1
logInfo("Closing .NET process")
try {
dotnetProcess.destroy()
returnCode = dotnetProcess.waitFor()
} catch {
case _: InterruptedException =>
logInfo(
"Thread interrupted while waiting for graceful close. Forcefully closing .NET process")
returnCode = dotnetProcess.destroyForcibly().waitFor()
case t: Throwable =>
logThrowable(t)
}
returnCode
}
private def initializeSettings(args: Array[String]): (Boolean, Int) = {
val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase(
"debug")
var portNumber = 0
if (runInDebugMode) {
if (args.length == 1) {
portNumber = DEBUG_PORT
} else if (args.length == 2) {
portNumber = Integer.parseInt(args(1))
}
}
(runInDebugMode, portNumber)
}
private def logThrowable(throwable: Throwable): Unit =
logError(s"${throwable.getMessage} \n ${throwable.getStackTrace.mkString("\n")}")
}
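
For reference, the argument forms accepted by main, as parsed above and in initializeSettings:

// <app>.zip <executable> [app args...]  - zipped .NET driver (possibly remote); it is
//                                         downloaded, unzipped, and the executable resolved
// <executable> [app args...]            - .NET executable already present on the node
// debug [port]                          - debug mode: only the backend is started
//                                         (default port 5567); no .NET process is launched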

@@ -0,0 +1,18 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.internal.config.dotnet
import org.apache.spark.internal.config.ConfigBuilder
private[spark] object Dotnet {
val DOTNET_NUM_BACKEND_THREADS = ConfigBuilder("spark.dotnet.numDotnetBackendThreads").intConf
.createWithDefault(10)
val DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK =
ConfigBuilder("spark.dotnet.ignoreSparkPatchVersionCheck").booleanConf
.createWithDefault(false)
}
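
Both entries are ordinary Spark configuration keys; a sketch of setting them programmatically (the values here are arbitrary):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.dotnet.numDotnetBackendThreads", "16")
  .set("spark.dotnet.ignoreSparkPatchVersionCheck", "true")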

@@ -0,0 +1,33 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.sql.api.dotnet
import org.apache.spark.api.dotnet.CallbackClient
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.streaming.DataStreamWriter
class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int) extends Logging {
def call(batchDF: DataFrame, batchId: Long): Unit =
callbackClient.send(
callbackId,
(dos, serDe) => {
serDe.writeJObj(dos, batchDF)
serDe.writeLong(dos, batchId)
})
}
object DotnetForeachBatchHelper {
def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = {
val dotnetForeachFunc = client match {
case Some(value) => new DotnetForeachBatchFunction(value, callbackId)
case None => throw new Exception("CallbackClient is null.")
}
dsw.foreachBatch(dotnetForeachFunc.call _)
}
}

@@ -0,0 +1,37 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.sql.api.dotnet
import java.util.{List => JList, Map => JMap}
import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonFunction}
import org.apache.spark.broadcast.Broadcast
object SQLUtils {
/**
* Exposes createPythonFunction to the .NET client to enable registering UDFs.
*/
def createPythonFunction(
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVersion: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: PythonAccumulatorV2): PythonFunction = {
PythonFunction(
command,
envVars,
pythonIncludes,
pythonExec,
pythonVersion,
broadcastVars,
accumulator)
}
}

@@ -0,0 +1,30 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.sql.test
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.MemoryStream
object TestUtils {
/**
* Helper method to create typed MemoryStreams intended for use in unit tests.
* @param sqlContext The SQLContext.
* @param streamType The type of memory stream to create. This string is the `Name`
* property of the dotnet type.
* @return A typed MemoryStream.
*/
def createMemoryStream(implicit sqlContext: SQLContext, streamType: String): MemoryStream[_] = {
import sqlContext.implicits._
streamType match {
case "Int32" => MemoryStream[Int]
case "String" => MemoryStream[String]
case _ => throw new Exception(s"$streamType not supported")
}
}
}

@@ -0,0 +1,254 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.util.dotnet
import java.io._
import java.nio.file.attribute.PosixFilePermission
import java.nio.file.attribute.PosixFilePermission._
import java.nio.file.{FileSystems, Files}
import java.util.{Timer, TimerTask}
import org.apache.commons.compress.archivers.zip.{ZipArchiveEntry, ZipArchiveOutputStream, ZipFile}
import org.apache.commons.io.{FileUtils, IOUtils}
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK
import scala.collection.JavaConverters._
import scala.collection.Set
/**
* Utility methods.
*/
object Utils extends Logging {
private val posixFilePermissions = Array(
OWNER_READ,
OWNER_WRITE,
OWNER_EXECUTE,
GROUP_READ,
GROUP_WRITE,
GROUP_EXECUTE,
OTHERS_READ,
OTHERS_WRITE,
OTHERS_EXECUTE)
val supportPosix: Boolean =
FileSystems.getDefault.supportedFileAttributeViews().contains("posix")
/**
* Compress all files under given directory into one zip file and drop it to the target directory
*
* @param sourceDir source directory to zip
* @param targetZipFile target zip file
*/
def zip(sourceDir: File, targetZipFile: File): Unit = {
var fos: FileOutputStream = null
var zos: ZipArchiveOutputStream = null
try {
fos = new FileOutputStream(targetZipFile)
zos = new ZipArchiveOutputStream(fos)
val sourcePath = sourceDir.toPath
FileUtils.listFiles(sourceDir, null, true).asScala.foreach { file =>
var in: FileInputStream = null
try {
val path = file.toPath
val entry = new ZipArchiveEntry(sourcePath.relativize(path).toString)
if (supportPosix) {
entry.setUnixMode(
permissionsToMode(Files.getPosixFilePermissions(path).asScala)
| (if (entry.getName.endsWith(".exe")) 0x1ED else 0x1A4))
} else if (entry.getName.endsWith(".exe")) {
entry.setUnixMode(0x1ED) // 755
} else {
entry.setUnixMode(0x1A4) // 644
}
zos.putArchiveEntry(entry)
in = new FileInputStream(file)
IOUtils.copy(in, zos)
zos.closeArchiveEntry()
} finally {
IOUtils.closeQuietly(in)
}
}
} finally {
IOUtils.closeQuietly(zos)
IOUtils.closeQuietly(fos)
}
}
/**
* Unzip a file to the given directory
*
* @param file file to be unzipped
* @param targetDir target directory
*/
def unzip(file: File, targetDir: File): Unit = {
var zipFile: ZipFile = null
try {
targetDir.mkdirs()
zipFile = new ZipFile(file)
zipFile.getEntries.asScala.foreach { entry =>
val targetFile = new File(targetDir, entry.getName)
if (targetFile.exists()) {
logWarning(
s"Target file/directory $targetFile already exists. Skip it for now. " +
s"Make sure this is expected.")
} else {
if (entry.isDirectory) {
targetFile.mkdirs()
} else {
targetFile.getParentFile.mkdirs()
val input = zipFile.getInputStream(entry)
val output = new FileOutputStream(targetFile)
IOUtils.copy(input, output)
IOUtils.closeQuietly(input)
IOUtils.closeQuietly(output)
if (supportPosix) {
val permissions = modeToPermissions(entry.getUnixMode)
// The permission set may be empty (e.g. when the zip file was created on Windows);
// skip setting empty permissions so the file's existing permissions are not cleared.
if (permissions.nonEmpty) {
Files.setPosixFilePermissions(targetFile.toPath, permissions.asJava)
}
}
}
}
}
} catch {
case e: Exception => logError("exception caught during decompression:" + e)
} finally {
ZipFile.closeQuietly(zipFile)
}
}
/**
* Exits the JVM, trying to do it nicely, otherwise doing it nastily.
*
* @param status the exit status, zero for OK, non-zero for error
* @param maxDelayMillis the maximum delay in milliseconds
*/
def exit(status: Int, maxDelayMillis: Long) {
try {
logInfo(s"Utils.exit() with status: $status, maxDelayMillis: $maxDelayMillis")
// setup a timer, so if nice exit fails, the nasty exit happens
val timer = new Timer()
timer.schedule(new TimerTask() {
override def run() {
Runtime.getRuntime.halt(status)
}
}, maxDelayMillis)
// try to exit nicely
System.exit(status);
} catch {
// exit nastily if we have a problem
case _: Throwable => Runtime.getRuntime.halt(status)
} finally {
// should never get here
Runtime.getRuntime.halt(status)
}
}
/**
* Exits the JVM, trying to do it nicely, waiting up to 1 second before forcing a halt.
*
* @param status the exit status, zero for OK, non-zero for error
*/
def exit(status: Int): Unit = {
exit(status, 1000)
}
/**
* Normalize the Spark version by taking the first three numbers.
* For example:
* x.y.z => x.y.z
* x.y.z.xxx.yyy => x.y.z
* x.y => x.y
*
* @param version the Spark version to normalize
* @return Normalized Spark version.
*/
def normalizeSparkVersion(version: String): String = {
version
.split('.')
.take(3)
.zipWithIndex
.map({
case (element, index) => {
index match {
case 2 => element.split("\\D+").lift(0).getOrElse("")
case _ => element
}
}
})
.mkString(".")
}
/**
* Validates the normalized spark version by verifying:
* - Spark version starts with sparkMajorMinorVersionPrefix.
* - If ignoreSparkPatchVersion is
* - true: valid
* - false: check if the spark version is in supportedSparkVersions.
* @param ignoreSparkPatchVersion Ignore spark patch version.
* @param sparkVersion The spark version.
* @param normalizedSparkVersion: The normalized spark version.
* @param supportedSparkMajorMinorVersionPrefix The spark major and minor version to validate against.
* @param supportedSparkVersions The set of supported spark versions.
*/
def validateSparkVersions(
ignoreSparkPatchVersion: Boolean,
sparkVersion: String,
normalizedSparkVersion: String,
supportedSparkMajorMinorVersionPrefix: String,
supportedSparkVersions: Set[String]): Unit = {
if (!normalizedSparkVersion.startsWith(s"$supportedSparkMajorMinorVersionPrefix.")) {
throw new IllegalArgumentException(
s"Unsupported spark version used: '$sparkVersion'. " +
s"Normalized spark version used: '$normalizedSparkVersion'. " +
s"Supported spark major.minor version: '$supportedSparkMajorMinorVersionPrefix'.")
} else if (ignoreSparkPatchVersion) {
logWarning(
s"Ignoring spark patch version. Spark version used: '$sparkVersion'. " +
s"Normalized spark version used: '$normalizedSparkVersion'. " +
s"Spark major.minor prefix used: '$supportedSparkMajorMinorVersionPrefix'.")
} else if (!supportedSparkVersions(normalizedSparkVersion)) {
val supportedVersions = supportedSparkVersions.toSeq.sorted.mkString(", ")
throw new IllegalArgumentException(
s"Unsupported spark version used: '$sparkVersion'. " +
s"Normalized spark version used: '$normalizedSparkVersion'. " +
s"Supported versions: '$supportedVersions'.")
}
}
private[spark] def listZipFileEntries(file: File): Array[String] = {
var zipFile: ZipFile = null
try {
zipFile = new ZipFile(file)
zipFile.getEntries.asScala.map(_.getName).toArray
} finally {
ZipFile.closeQuietly(zipFile)
}
}
private[this] def permissionsToMode(permissions: Set[PosixFilePermission]): Int = {
posixFilePermissions.foldLeft(0) { (mode, perm) =>
(mode << 1) | (if (permissions.contains(perm)) 1 else 0)
}
}
private[this] def modeToPermissions(mode: Int): Set[PosixFilePermission] = {
posixFilePermissions.zipWithIndex
.filter { case (_, i) => (mode & (0x100 >>> i)) != 0 }
.map(_._1)
.toSet
}
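// For example, assuming posixFilePermissions (declared earlier in this file) lists the nine
// PosixFilePermission values in the order OWNER_READ through OTHERS_EXECUTE, the set
// {OWNER_READ, OWNER_WRITE, OWNER_EXECUTE, GROUP_READ, GROUP_EXECUTE, OTHERS_READ, OTHERS_EXECUTE}
// maps to mode 0x1ED (octal 755), and modeToPermissions(0x1ED) yields the same set back, so the
// two helpers are inverses of each other.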
}
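
A minimal usage sketch of the two version helpers above. The wrapper object, the main method, and the sample inputs are illustrative rather than part of this commit; the sketch sits in the same package purely so the Utils object is in scope.

package org.apache.spark.util.dotnet

object VersionCheckSketch {
  def main(args: Array[String]): Unit = {
    val sparkVersion = "3.2.0.2.6.99.201-abc" // hypothetical vendor build string
    val normalized = Utils.normalizeSparkVersion(sparkVersion) // yields "3.2.0"

    // Passes silently: the normalized version starts with "3.2." and is in the supported set.
    Utils.validateSparkVersions(
      ignoreSparkPatchVersion = false,
      sparkVersion = sparkVersion,
      normalizedSparkVersion = normalized,
      supportedSparkMajorMinorVersionPrefix = "3.2",
      supportedSparkVersions = Set("3.2.0"))
  }
}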

Просмотреть файл

@ -0,0 +1,68 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import Extensions._
import org.junit.Assert._
import org.junit.{After, Before, Test}
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
@Test
class DotnetBackendHandlerTest {
private var backend: DotnetBackend = _
private var tracker: JVMObjectTracker = _
private var handler: DotnetBackendHandler = _
@Before
def before(): Unit = {
backend = new DotnetBackend
tracker = new JVMObjectTracker
handler = new DotnetBackendHandler(backend, tracker)
}
@After
def after(): Unit = {
backend.close()
}
@Test
def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = {
val message = givenMessage(m => {
val serDe = new SerDe(null)
m.writeBoolean(true) // static method
serDe.writeInt(m, 1) // processId
serDe.writeInt(m, 1) // threadId
serDe.writeString(m, "DotnetHandler") // class name
serDe.writeString(m, "connectCallback") // command (method) name
m.writeInt(2) // number of arguments
m.writeByte('c') // 1st argument type (string)
serDe.writeString(m, "127.0.0.1") // 1st argument value (host)
m.writeByte('i') // 2nd argument type (integer)
m.writeInt(0) // 2nd argument value (port)
})
val payload = handler.handleBackendRequest(message)
val reply = new DataInputStream(new ByteArrayInputStream(payload))
assertEquals(
"status code must be successful.", 0, reply.readInt())
assertEquals('j', reply.readByte())
assertEquals(1, reply.readInt())
val trackingId = new String(reply.readNBytes(1), "UTF-8")
assertEquals("1", trackingId)
val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull
assertEquals(classOf[CallbackClient], client.getClass)
}
private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = {
val buffer = new ByteArrayOutputStream()
func(new DataOutputStream(buffer))
buffer.toByteArray
}
}

Просмотреть файл

@ -0,0 +1,39 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import org.junit.Assert._
import org.junit.{After, Before, Test}
import java.net.InetAddress
@Test
class DotnetBackendTest {
private var backend: DotnetBackend = _
@Before
def before(): Unit = {
backend = new DotnetBackend
}
@After
def after(): Unit = {
backend.close()
}
@Test
def shouldNotResetCallbackClient(): Unit = {
// Specifying port = 0 to select port dynamically.
backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
assertTrue(backend.callbackClient.isDefined)
assertThrows(classOf[Exception], () => {
backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0)
})
}
}

Просмотреть файл

@ -0,0 +1,20 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import java.io.DataInputStream
private[dotnet] object Extensions {
implicit class DataInputStreamExt(stream: DataInputStream) {
def readNBytes(n: Int): Array[Byte] = {
val buf = new Array[Byte](n)
stream.readFully(buf)
buf
}
}
}

Просмотреть файл

@ -0,0 +1,42 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import org.junit.Test
@Test
class JVMObjectTrackerTest {
@Test
def shouldReleaseAllReferences(): Unit = {
val tracker = new JVMObjectTracker
val firstId = tracker.put(new Object)
val secondId = tracker.put(new Object)
val thirdId = tracker.put(new Object)
tracker.clear()
assert(tracker.get(firstId).isEmpty)
assert(tracker.get(secondId).isEmpty)
assert(tracker.get(thirdId).isEmpty)
}
@Test
def shouldResetCounter(): Unit = {
val tracker = new JVMObjectTracker
val firstId = tracker.put(new Object)
val secondId = tracker.put(new Object)
tracker.clear()
val thirdId = tracker.put(new Object)
assert(firstId.equals("1"))
assert(secondId.equals("2"))
assert(thirdId.equals("1"))
}
}

Просмотреть файл

@ -0,0 +1,373 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.api.dotnet
import org.apache.spark.api.dotnet.Extensions._
import org.apache.spark.sql.Row
import org.junit.Assert._
import org.junit.{Before, Test}
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.sql.Date
import scala.collection.JavaConverters._
@Test
class SerDeTest {
private var serDe: SerDe = _
private var tracker: JVMObjectTracker = _
@Before
def before(): Unit = {
tracker = new JVMObjectTracker
serDe = new SerDe(tracker)
}
@Test
def shouldReadNull(): Unit = {
val input = givenInput(in => {
in.writeByte('n')
})
assertEquals(null, serDe.readObject(input))
}
@Test
def shouldThrowForUnsupportedTypes(): Unit = {
val input = givenInput(in => {
in.writeByte('_')
})
assertThrows(classOf[IllegalArgumentException], () => {
serDe.readObject(input)
})
}
@Test
def shouldReadInteger(): Unit = {
val input = givenInput(in => {
in.writeByte('i')
in.writeInt(42)
})
assertEquals(42, serDe.readObject(input))
}
@Test
def shouldReadLong(): Unit = {
val input = givenInput(in => {
in.writeByte('g')
in.writeLong(42)
})
assertEquals(42L, serDe.readObject(input))
}
@Test
def shouldReadDouble(): Unit = {
val input = givenInput(in => {
in.writeByte('d')
in.writeDouble(42.42)
})
assertEquals(42.42, serDe.readObject(input))
}
@Test
def shouldReadBoolean(): Unit = {
val input = givenInput(in => {
in.writeByte('b')
in.writeBoolean(true)
})
assertEquals(true, serDe.readObject(input))
}
@Test
def shouldReadString(): Unit = {
val payload = "Spark Dotnet"
val input = givenInput(in => {
in.writeByte('c')
in.writeInt(payload.getBytes("UTF-8").length)
in.write(payload.getBytes("UTF-8"))
})
assertEquals(payload, serDe.readObject(input))
}
@Test
def shouldReadMap(): Unit = {
val input = givenInput(in => {
in.writeByte('e') // map type descriptor
in.writeInt(3) // size
in.writeByte('i') // key type
in.writeInt(3) // number of keys
in.writeInt(11) // first key
in.writeInt(22) // second key
in.writeInt(33) // third key
in.writeInt(3) // number of values
in.writeByte('b') // first value type
in.writeBoolean(true) // first value
in.writeByte('d') // second value type
in.writeDouble(42.42) // second value
in.writeByte('n') // third type & value
})
assertEquals(
mapAsJavaMap(Map(
11 -> true,
22 -> 42.42,
33 -> null)),
serDe.readObject(input))
}
@Test
def shouldReadEmptyMap(): Unit = {
val input = givenInput(in => {
in.writeByte('e') // map type descriptor
in.writeInt(0) // size
})
assertEquals(mapAsJavaMap(Map()), serDe.readObject(input))
}
@Test
def shouldReadBytesArray(): Unit = {
val input = givenInput(in => {
in.writeByte('r') // byte array type descriptor
in.writeInt(3) // length
in.write(Array[Byte](1, 2, 3)) // payload
})
assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]])
}
@Test
def shouldReadEmptyBytesArray(): Unit = {
val input = givenInput(in => {
in.writeByte('r') // byte array type descriptor
in.writeInt(0) // length
})
assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]])
}
@Test
def shouldReadEmptyList(): Unit = {
val input = givenInput(in => {
in.writeByte('l') // type descriptor
in.writeByte('i') // element type
in.writeInt(0) // length
})
assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]])
}
@Test
def shouldReadList(): Unit = {
val input = givenInput(in => {
in.writeByte('l') // type descriptor
in.writeByte('b') // element type
in.writeInt(3) // length
in.writeBoolean(true)
in.writeBoolean(false)
in.writeBoolean(true)
})
assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]])
}
@Test
def shouldThrowWhenReadingListWithUnsupportedType(): Unit = {
val input = givenInput(in => {
in.writeByte('l') // type descriptor
in.writeByte('_') // unsupported element type
})
assertThrows(classOf[IllegalArgumentException], () => {
serDe.readObject(input)
})
}
@Test
def shouldReadDate(): Unit = {
val input = givenInput(in => {
val date = "2020-12-31"
in.writeByte('D') // type descriptor
in.writeInt(date.getBytes("UTF-8").length) // date string size
in.write(date.getBytes("UTF-8"))
})
assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input))
}
@Test
def shouldReadObject(): Unit = {
val trackingObject = new Object
tracker.put(trackingObject)
val input = givenInput(in => {
val objectIndex = "1"
in.writeByte('j') // type descriptor
in.writeInt(objectIndex.getBytes("UTF-8").length) // size
in.write(objectIndex.getBytes("UTF-8"))
})
assertSame(trackingObject, serDe.readObject(input))
}
@Test
def shouldThrowWhenReadingNonTrackingObject(): Unit = {
val input = givenInput(in => {
val objectIndex = "42"
in.writeByte('j') // type descriptor
in.writeInt(objectIndex.getBytes("UTF-8").length) // size
in.write(objectIndex.getBytes("UTF-8"))
})
assertThrows(classOf[NoSuchElementException], () => {
serDe.readObject(input)
})
}
@Test
def shouldReadSparkRows(): Unit = {
val input = givenInput(in => {
in.writeByte('R') // type descriptor
in.writeInt(2) // number of rows
in.writeInt(1) // number of elements in 1st row
in.writeByte('i') // type of 1st element in 1st row
in.writeInt(11)
in.writeInt(3) // number of elements in 2nd row
in.writeByte('b') // type of 1st element in 2nd row
in.writeBoolean(true)
in.writeByte('d') // type of 2nd element in 2nd row
in.writeDouble(42.24)
in.writeByte('g') // type of 3rd element in 2nd row
in.writeLong(99)
})
assertEquals(
seqAsJavaList(Seq(
Row.fromSeq(Seq(11)),
Row.fromSeq(Seq(true, 42.24, 99)))),
serDe.readObject(input))
}
@Test
def shouldReadArrayOfObjects(): Unit = {
val input = givenInput(in => {
in.writeByte('O') // type descriptor
in.writeInt(2) // number of elements
in.writeByte('i') // type of 1st element
in.writeInt(42)
in.writeByte('b') // type of 2nd element
in.writeBoolean(true)
})
assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]])
}
@Test
def shouldWriteNull(): Unit = {
val in = whenOutput(out => {
serDe.writeObject(out, null)
serDe.writeObject(out, Unit)
})
assertEquals(in.readByte(), 'n')
assertEquals(in.readByte(), 'n')
assertEndOfStream(in)
}
@Test
def shouldWriteString(): Unit = {
val sparkDotnet = "Spark Dotnet"
val in = whenOutput(out => {
serDe.writeObject(out, sparkDotnet)
})
assertEquals(in.readByte(), 'c') // object type
assertEquals(in.readInt(), sparkDotnet.length) // length
assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8"))
assertEndOfStream(in)
}
@Test
def shouldWritePrimitiveTypes(): Unit = {
val in = whenOutput(out => {
serDe.writeObject(out, 42.24f.asInstanceOf[Object])
serDe.writeObject(out, 42L.asInstanceOf[Object])
serDe.writeObject(out, 42.asInstanceOf[Object])
serDe.writeObject(out, true.asInstanceOf[Object])
})
assertEquals(in.readByte(), 'd')
assertEquals(in.readDouble(), 42.24F, 0.000001)
assertEquals(in.readByte(), 'g')
assertEquals(in.readLong(), 42L)
assertEquals(in.readByte(), 'i')
assertEquals(in.readInt(), 42)
assertEquals(in.readByte(), 'b')
assertEquals(in.readBoolean(), true)
assertEndOfStream(in)
}
@Test
def shouldWriteDate(): Unit = {
val date = "2020-12-31"
val in = whenOutput(out => {
serDe.writeObject(out, Date.valueOf(date))
})
assertEquals(in.readByte(), 'D') // type
assertEquals(in.readInt(), 10) // size
assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content
}
@Test
def shouldWriteCustomObjects(): Unit = {
val customObject = new Object
val in = whenOutput(out => {
serDe.writeObject(out, customObject)
})
assertEquals(in.readByte(), 'j')
assertEquals(in.readInt(), 1)
assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8"))
assertSame(tracker.get("1").get, customObject)
}
@Test
def shouldWriteArrayOfCustomObjects(): Unit = {
val payload = Array(new Object, new Object)
val in = whenOutput(out => {
serDe.writeObject(out, payload)
})
assertEquals(in.readByte(), 'l') // array type
assertEquals(in.readByte(), 'j') // type of element in array
assertEquals(in.readInt(), 2) // array length
assertEquals(in.readInt(), 1) // size of 1st element's identifier
assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element
assertEquals(in.readInt(), 1) // size of 2nd element's identifier
assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element
assertSame(tracker.get("1").get, payload(0))
assertSame(tracker.get("2").get, payload(1))
}
private def givenInput(func: DataOutputStream => Unit): DataInputStream = {
val buffer = new ByteArrayOutputStream()
val out = new DataOutputStream(buffer)
func(out)
new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
}
private def whenOutput = givenInput _
private def assertEndOfStream(in: DataInputStream): Unit = {
assertEquals(-1, in.read())
}
}
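
For orientation alongside the byte-level tests above, a minimal round-trip sketch. The wrapper object is illustrative and sits in the same package only so that SerDe and JVMObjectTracker are in scope.

package org.apache.spark.api.dotnet

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object SerDeRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val serDe = new SerDe(new JVMObjectTracker)

    // Write a string: SerDe emits the 'c' type descriptor, the UTF-8 length, then the bytes.
    val buffer = new ByteArrayOutputStream()
    serDe.writeObject(new DataOutputStream(buffer), "Spark Dotnet")

    // Read it back from the same bytes.
    val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
    println(serDe.readObject(in)) // prints: Spark Dotnet
  }
}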

Просмотреть файл

@ -0,0 +1,80 @@
/*
* Licensed to the .NET Foundation under one or more agreements.
* The .NET Foundation licenses this file to you under the MIT license.
* See the LICENSE file in the project root for more information.
*/
package org.apache.spark.util.dotnet
import org.apache.spark.SparkConf
import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK
import org.junit.Assert.{assertEquals, assertThrows}
import org.junit.Test
@Test
class UtilsTest {
@Test
def shouldIgnorePatchVersion(): Unit = {
val sparkVersion = "3.2.1"
val sparkMajorMinorVersionPrefix = "3.2"
val supportedSparkVersions = Set[String]("3.2.0")
Utils.validateSparkVersions(
true,
sparkVersion,
Utils.normalizeSparkVersion(sparkVersion),
sparkMajorMinorVersionPrefix,
supportedSparkVersions)
}
@Test
def shouldThrowForUnsupportedVersion(): Unit = {
val sparkVersion = "3.2.1"
val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion)
val sparkMajorMinorVersionPrefix = "3.2"
val supportedSparkVersions = Set[String]("3.2.0")
val exception = assertThrows(
classOf[IllegalArgumentException],
() => {
Utils.validateSparkVersions(
false,
sparkVersion,
normalizedSparkVersion,
sparkMajorMinorVersionPrefix,
supportedSparkVersions)
})
assertEquals(
s"Unsupported spark version used: '$sparkVersion'. " +
s"Normalized spark version used: '$normalizedSparkVersion'. " +
s"Supported versions: '${supportedSparkVersions.toSeq.sorted.mkString(", ")}'.",
exception.getMessage)
}
@Test
def shouldThrowForUnsupportedMajorMinorVersion(): Unit = {
val sparkVersion = "2.4.4"
val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion)
val sparkMajorMinorVersionPrefix = "3.2"
val supportedSparkVersions = Set[String]("3.2.0")
val exception = assertThrows(
classOf[IllegalArgumentException],
() => {
Utils.validateSparkVersions(
false,
sparkVersion,
normalizedSparkVersion,
sparkMajorMinorVersionPrefix,
supportedSparkVersions)
})
assertEquals(
s"Unsupported spark version used: '$sparkVersion'. " +
s"Normalized spark version used: '$normalizedSparkVersion'. " +
s"Supported spark major.minor version: '$sparkMajorMinorVersionPrefix'.",
exception.getMessage)
}
}

Просмотреть файл

@ -14,6 +14,7 @@
<module>microsoft-spark-2-4</module>
<module>microsoft-spark-3-0</module>
<module>microsoft-spark-3-1</module>
<module>microsoft-spark-3-2</module>
</modules>
<pluginRepositories>