Update worker contract for Spark-3.0. (#311)

Terry Kim 2019-11-21 07:57:58 -08:00 committed by GitHub
Parent e82abf3ffb
Commit 40c9febd54
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 784 additions and 719 deletions
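For orientation, the core worker-contract change in this diff is a new task-context section that the .NET worker reads for Spark 3.0: barrier info, the core task fields, a new resources block, and the local properties (see TaskContextWriterV3_0_X and TaskContextProcessorV3_0_X below). The following is a minimal, self-contained sketch of that read order only; the BinaryReader-based helpers and the length-prefixed string format are illustrative stand-ins and do not reproduce Microsoft.Spark's actual SerDe encoding.

using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

internal static class TaskContextV30Sketch
{
    // Reads the Spark 3.0 task-context section in the order the worker expects:
    // barrier info, core task fields, resources (new in 3.0), local properties.
    public static void ReadTaskContext(Stream stream)
    {
        using var reader = new BinaryReader(stream);

        _ = reader.ReadBoolean();       // IsBarrier (barrier mode is not supported)
        _ = reader.ReadInt32();         // BoundPort
        _ = ReadString(reader);         // Secret

        int stageId = reader.ReadInt32();
        int partitionId = reader.ReadInt32();
        int attemptNumber = reader.ReadInt32();
        long attemptId = reader.ReadInt64();

        int numResources = reader.ReadInt32();
        for (int i = 0; i < numResources; ++i)
        {
            _ = ReadString(reader);     // resource key
            _ = ReadString(reader);     // resource value
            int numAddresses = reader.ReadInt32();
            for (int j = 0; j < numAddresses; ++j)
            {
                _ = ReadString(reader); // resource address
            }
        }

        var localProperties = new Dictionary<string, string>();
        int numProperties = reader.ReadInt32();
        for (int i = 0; i < numProperties; ++i)
        {
            localProperties[ReadString(reader)] = ReadString(reader);
        }

        Console.WriteLine(
            $"stage={stageId} partition={partitionId} attempt={attemptNumber}/{attemptId} " +
            $"properties={localProperties.Count}");
    }

    // Length-prefixed UTF-8 string; an assumption for this sketch only.
    private static string ReadString(BinaryReader reader) =>
        Encoding.UTF8.GetString(reader.ReadBytes(reader.ReadInt32()));
}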

View file

@@ -31,76 +31,74 @@ namespace Microsoft.Spark.Extensions.Delta.E2ETest
[SkipIfSparkVersionIsLessThan(Versions.V2_4_2)]
public void TestTutorialScenario()
{
using (var tempDirectory = new TemporaryDirectory())
{
string path = Path.Combine(tempDirectory.Path, "delta-table");
using var tempDirectory = new TemporaryDirectory();
string path = Path.Combine(tempDirectory.Path, "delta-table");
// Write data to a Delta table.
DataFrame data = _spark.Range(0, 5);
data.Write().Format("delta").Save(path);
// Write data to a Delta table.
DataFrame data = _spark.Range(0, 5);
data.Write().Format("delta").Save(path);
// Validate that data contains the sequence [0 ... 4].
ValidateRangeDataFrame(Enumerable.Range(0, 5), data);
// Validate that data contains the sequence [0 ... 4].
ValidateRangeDataFrame(Enumerable.Range(0, 5), data);
// Create a second iteration of the table.
data = _spark.Range(5, 10);
data.Write().Format("delta").Mode("overwrite").Save(path);
// Create a second iteration of the table.
data = _spark.Range(5, 10);
data.Write().Format("delta").Mode("overwrite").Save(path);
// Load the data into a DeltaTable object.
var deltaTable = DeltaTable.ForPath(path);
// Load the data into a DeltaTable object.
var deltaTable = DeltaTable.ForPath(path);
// Validate that deltaTable contains the sequence [5 ... 9].
ValidateRangeDataFrame(Enumerable.Range(5, 5), deltaTable.ToDF());
// Validate that deltaTable contains the sequence [5 ... 9].
ValidateRangeDataFrame(Enumerable.Range(5, 5), deltaTable.ToDF());
// Update every even value by adding 100 to it.
deltaTable.Update(
condition: Functions.Expr("id % 2 == 0"),
set: new Dictionary<string, Column>() {
// Update every even value by adding 100 to it.
deltaTable.Update(
condition: Functions.Expr("id % 2 == 0"),
set: new Dictionary<string, Column>() {
{ "id", Functions.Expr("id + 100") }
});
});
// Validate that deltaTable contains the data:
// +---+
// | id|
// +---+
// | 5|
// | 7|
// | 9|
// |106|
// |108|
// +---+
ValidateRangeDataFrame(
new List<int>() { 5, 7, 9, 106, 108 },
deltaTable.ToDF());
// Validate that deltaTable contains the data:
// +---+
// | id|
// +---+
// | 5|
// | 7|
// | 9|
// |106|
// |108|
// +---+
ValidateRangeDataFrame(
new List<int>() { 5, 7, 9, 106, 108 },
deltaTable.ToDF());
// Delete every even value.
deltaTable.Delete(condition: Functions.Expr("id % 2 == 0"));
// Delete every even value.
deltaTable.Delete(condition: Functions.Expr("id % 2 == 0"));
// Validate that deltaTable contains:
// +---+
// | id|
// +---+
// | 5|
// | 7|
// | 9|
// +---+
ValidateRangeDataFrame(new List<int>() { 5, 7, 9 }, deltaTable.ToDF());
// Validate that deltaTable contains:
// +---+
// | id|
// +---+
// | 5|
// | 7|
// | 9|
// +---+
ValidateRangeDataFrame(new List<int>() { 5, 7, 9 }, deltaTable.ToDF());
// Upsert (merge) new data.
DataFrame newData = _spark.Range(0, 20).As("newData").ToDF();
// Upsert (merge) new data.
DataFrame newData = _spark.Range(0, 20).As("newData").ToDF();
deltaTable.As("oldData")
.Merge(newData, "oldData.id = newData.id")
.WhenMatched()
.Update(
new Dictionary<string, Column>() { { "id", Functions.Col("newData.id") } })
.WhenNotMatched()
.InsertExpr(new Dictionary<string, string>() { { "id", "newData.id" } })
.Execute();
deltaTable.As("oldData")
.Merge(newData, "oldData.id = newData.id")
.WhenMatched()
.Update(
new Dictionary<string, Column>() { { "id", Functions.Col("newData.id") } })
.WhenNotMatched()
.InsertExpr(new Dictionary<string, string>() { { "id", "newData.id" } })
.Execute();
// Validate that the resulting table contains the sequence [0 ... 19].
ValidateRangeDataFrame(Enumerable.Range(0, 20), deltaTable.ToDF());
}
// Validate that the resulting table contains the sequence [0 ... 19].
ValidateRangeDataFrame(Enumerable.Range(0, 20), deltaTable.ToDF());
}
/// <summary>
@@ -109,39 +107,37 @@ namespace Microsoft.Spark.Extensions.Delta.E2ETest
[SkipIfSparkVersionIsLessThan(Versions.V2_4_2)]
public void TestStreamingScenario()
{
using (var tempDirectory = new TemporaryDirectory())
{
// Write [0, 1, 2, 3, 4] to a Delta table.
string sourcePath = Path.Combine(tempDirectory.Path, "source-delta-table");
_spark.Range(0, 5).Write().Format("delta").Save(sourcePath);
using var tempDirectory = new TemporaryDirectory();
// Write [0, 1, 2, 3, 4] to a Delta table.
string sourcePath = Path.Combine(tempDirectory.Path, "source-delta-table");
_spark.Range(0, 5).Write().Format("delta").Save(sourcePath);
// Create a stream from the source DeltaTable to the sink DeltaTable.
// To make the test synchronous and deterministic, we will use a series of
// "one-time micro-batch" triggers.
string sinkPath = Path.Combine(tempDirectory.Path, "sink-delta-table");
DataStreamWriter dataStreamWriter = _spark
.ReadStream()
.Format("delta")
.Load(sourcePath)
.WriteStream()
.Format("delta")
.OutputMode("append")
.Option("checkpointLocation", Path.Combine(tempDirectory.Path, "checkpoints"));
// Create a stream from the source DeltaTable to the sink DeltaTable.
// To make the test synchronous and deterministic, we will use a series of
// "one-time micro-batch" triggers.
string sinkPath = Path.Combine(tempDirectory.Path, "sink-delta-table");
DataStreamWriter dataStreamWriter = _spark
.ReadStream()
.Format("delta")
.Load(sourcePath)
.WriteStream()
.Format("delta")
.OutputMode("append")
.Option("checkpointLocation", Path.Combine(tempDirectory.Path, "checkpoints"));
// Trigger the first stream batch
dataStreamWriter.Trigger(Trigger.Once()).Start(sinkPath).AwaitTermination();
// Trigger the first stream batch
dataStreamWriter.Trigger(Trigger.Once()).Start(sinkPath).AwaitTermination();
// Now read the sink DeltaTable and validate its content.
DeltaTable sink = DeltaTable.ForPath(sinkPath);
ValidateRangeDataFrame(Enumerable.Range(0, 5), sink.ToDF());
// Now read the sink DeltaTable and validate its content.
DeltaTable sink = DeltaTable.ForPath(sinkPath);
ValidateRangeDataFrame(Enumerable.Range(0, 5), sink.ToDF());
// Write [5,6,7,8,9] to the source and trigger another stream batch.
_spark.Range(5, 10).Write().Format("delta").Mode("append").Save(sourcePath);
dataStreamWriter.Trigger(Trigger.Once()).Start(sinkPath).AwaitTermination();
// Write [5,6,7,8,9] to the source and trigger another stream batch.
_spark.Range(5, 10).Write().Format("delta").Mode("append").Save(sourcePath);
dataStreamWriter.Trigger(Trigger.Once()).Start(sinkPath).AwaitTermination();
// Finally, validate that the new data made its way to the sink.
ValidateRangeDataFrame(Enumerable.Range(0, 10), sink.ToDF());
}
// Finally, validate that the new data made its way to the sink.
ValidateRangeDataFrame(Enumerable.Range(0, 10), sink.ToDF());
}
/// <summary>
@@ -150,21 +146,19 @@ namespace Microsoft.Spark.Extensions.Delta.E2ETest
[SkipIfSparkVersionIsLessThan(Versions.V2_4_2)]
public void TestIsDeltaTable()
{
using (var tempDirectory = new TemporaryDirectory())
{
// Save the same data to a DeltaTable and to Parquet.
DataFrame data = _spark.Range(0, 5);
string parquetPath = Path.Combine(tempDirectory.Path, "parquet-data");
data.Write().Parquet(parquetPath);
string deltaTablePath = Path.Combine(tempDirectory.Path, "delta-table");
data.Write().Format("delta").Save(deltaTablePath);
using var tempDirectory = new TemporaryDirectory();
// Save the same data to a DeltaTable and to Parquet.
DataFrame data = _spark.Range(0, 5);
string parquetPath = Path.Combine(tempDirectory.Path, "parquet-data");
data.Write().Parquet(parquetPath);
string deltaTablePath = Path.Combine(tempDirectory.Path, "delta-table");
data.Write().Format("delta").Save(deltaTablePath);
Assert.False(DeltaTable.IsDeltaTable(parquetPath));
Assert.False(DeltaTable.IsDeltaTable(_spark, parquetPath));
Assert.False(DeltaTable.IsDeltaTable(parquetPath));
Assert.False(DeltaTable.IsDeltaTable(_spark, parquetPath));
Assert.True(DeltaTable.IsDeltaTable(deltaTablePath));
Assert.True(DeltaTable.IsDeltaTable(_spark, deltaTablePath));
}
Assert.True(DeltaTable.IsDeltaTable(deltaTablePath));
Assert.True(DeltaTable.IsDeltaTable(_spark, deltaTablePath));
}
/// <summary>
@@ -184,26 +178,24 @@ namespace Microsoft.Spark.Extensions.Delta.E2ETest
Func<string, DeltaTable> convertToDelta,
string partitionColumn = null)
{
using (var tempDirectory = new TemporaryDirectory())
using var tempDirectory = new TemporaryDirectory();
string path = Path.Combine(tempDirectory.Path, "parquet-data");
DataFrameWriter dataWriter = dataFrame.Write();
if (!string.IsNullOrEmpty(partitionColumn))
{
string path = Path.Combine(tempDirectory.Path, "parquet-data");
DataFrameWriter dataWriter = dataFrame.Write();
if (!string.IsNullOrEmpty(partitionColumn))
{
dataWriter = dataWriter.PartitionBy(partitionColumn);
}
dataWriter.Parquet(path);
Assert.False(DeltaTable.IsDeltaTable(path));
string identifier = $"parquet.`{path}`";
DeltaTable convertedDeltaTable = convertToDelta(identifier);
ValidateRangeDataFrame(Enumerable.Range(0, 5), convertedDeltaTable.ToDF());
Assert.True(DeltaTable.IsDeltaTable(path));
dataWriter = dataWriter.PartitionBy(partitionColumn);
}
dataWriter.Parquet(path);
Assert.False(DeltaTable.IsDeltaTable(path));
string identifier = $"parquet.`{path}`";
DeltaTable convertedDeltaTable = convertToDelta(identifier);
ValidateRangeDataFrame(Enumerable.Range(0, 5), convertedDeltaTable.ToDF());
Assert.True(DeltaTable.IsDeltaTable(path));
}
testWrapper(data, identifier => DeltaTable.ConvertToDelta(_spark, identifier));
@@ -223,97 +215,95 @@ namespace Microsoft.Spark.Extensions.Delta.E2ETest
[SkipIfSparkVersionIsLessThan(Versions.V2_4_2)]
public void TestSignatures()
{
using (var tempDirectory = new TemporaryDirectory())
{
string path = Path.Combine(tempDirectory.Path, "delta-table");
using var tempDirectory = new TemporaryDirectory();
string path = Path.Combine(tempDirectory.Path, "delta-table");
DataFrame rangeRate = _spark.Range(15);
rangeRate.Write().Format("delta").Save(path);
DataFrame rangeRate = _spark.Range(15);
rangeRate.Write().Format("delta").Save(path);
DeltaTable table = Assert.IsType<DeltaTable>(DeltaTable.ForPath(path));
table = Assert.IsType<DeltaTable>(DeltaTable.ForPath(_spark, path));
DeltaTable table = Assert.IsType<DeltaTable>(DeltaTable.ForPath(path));
table = Assert.IsType<DeltaTable>(DeltaTable.ForPath(_spark, path));
Assert.IsType<bool>(DeltaTable.IsDeltaTable(_spark, path));
Assert.IsType<bool>(DeltaTable.IsDeltaTable(path));
Assert.IsType<bool>(DeltaTable.IsDeltaTable(_spark, path));
Assert.IsType<bool>(DeltaTable.IsDeltaTable(path));
Assert.IsType<DeltaTable>(table.As("oldTable"));
Assert.IsType<DeltaTable>(table.Alias("oldTable"));
Assert.IsType<DataFrame>(table.History());
Assert.IsType<DataFrame>(table.History(200));
Assert.IsType<DataFrame>(table.ToDF());
Assert.IsType<DeltaTable>(table.As("oldTable"));
Assert.IsType<DeltaTable>(table.Alias("oldTable"));
Assert.IsType<DataFrame>(table.History());
Assert.IsType<DataFrame>(table.History(200));
Assert.IsType<DataFrame>(table.ToDF());
DataFrame newTable = _spark.Range(10, 15).As("newTable");
Assert.IsType<DeltaMergeBuilder>(
table.Merge(newTable, Functions.Expr("oldTable.id == newTable.id")));
DeltaMergeBuilder mergeBuilder = Assert.IsType<DeltaMergeBuilder>(
table.Merge(newTable, "oldTable.id == newTable.id"));
DataFrame newTable = _spark.Range(10, 15).As("newTable");
Assert.IsType<DeltaMergeBuilder>(
table.Merge(newTable, Functions.Expr("oldTable.id == newTable.id")));
DeltaMergeBuilder mergeBuilder = Assert.IsType<DeltaMergeBuilder>(
table.Merge(newTable, "oldTable.id == newTable.id"));
// Validate the MergeBuilder matched signatures.
Assert.IsType<DeltaMergeMatchedActionBuilder>(mergeBuilder.WhenMatched());
Assert.IsType<DeltaMergeMatchedActionBuilder>(mergeBuilder.WhenMatched("id = 5"));
DeltaMergeMatchedActionBuilder matchedActionBuilder =
Assert.IsType<DeltaMergeMatchedActionBuilder>(
mergeBuilder.WhenMatched(Functions.Expr("id = 5")));
// Validate the MergeBuilder matched signatures.
Assert.IsType<DeltaMergeMatchedActionBuilder>(mergeBuilder.WhenMatched());
Assert.IsType<DeltaMergeMatchedActionBuilder>(mergeBuilder.WhenMatched("id = 5"));
DeltaMergeMatchedActionBuilder matchedActionBuilder =
Assert.IsType<DeltaMergeMatchedActionBuilder>(
mergeBuilder.WhenMatched(Functions.Expr("id = 5")));
Assert.IsType<DeltaMergeBuilder>(
matchedActionBuilder.Update(new Dictionary<string, Column>()));
Assert.IsType<DeltaMergeBuilder>(
matchedActionBuilder.UpdateExpr(new Dictionary<string, string>()));
Assert.IsType<DeltaMergeBuilder>(matchedActionBuilder.UpdateAll());
Assert.IsType<DeltaMergeBuilder>(matchedActionBuilder.Delete());
Assert.IsType<DeltaMergeBuilder>(
matchedActionBuilder.Update(new Dictionary<string, Column>()));
Assert.IsType<DeltaMergeBuilder>(
matchedActionBuilder.UpdateExpr(new Dictionary<string, string>()));
Assert.IsType<DeltaMergeBuilder>(matchedActionBuilder.UpdateAll());
Assert.IsType<DeltaMergeBuilder>(matchedActionBuilder.Delete());
// Validate the MergeBuilder not-matched signatures.
Assert.IsType<DeltaMergeNotMatchedActionBuilder>(mergeBuilder.WhenNotMatched());
// Validate the MergeBuilder not-matched signatures.
Assert.IsType<DeltaMergeNotMatchedActionBuilder>(mergeBuilder.WhenNotMatched());
Assert.IsType<DeltaMergeNotMatchedActionBuilder>(
mergeBuilder.WhenNotMatched("id = 5"));
DeltaMergeNotMatchedActionBuilder notMatchedActionBuilder =
Assert.IsType<DeltaMergeNotMatchedActionBuilder>(
mergeBuilder.WhenNotMatched("id = 5"));
DeltaMergeNotMatchedActionBuilder notMatchedActionBuilder =
Assert.IsType<DeltaMergeNotMatchedActionBuilder>(
mergeBuilder.WhenNotMatched(Functions.Expr("id = 5")));
mergeBuilder.WhenNotMatched(Functions.Expr("id = 5")));
Assert.IsType<DeltaMergeBuilder>(
notMatchedActionBuilder.Insert(new Dictionary<string, Column>()));
Assert.IsType<DeltaMergeBuilder>(
notMatchedActionBuilder.InsertExpr(new Dictionary<string, string>()));
Assert.IsType<DeltaMergeBuilder>(notMatchedActionBuilder.InsertAll());
Assert.IsType<DeltaMergeBuilder>(
notMatchedActionBuilder.Insert(new Dictionary<string, Column>()));
Assert.IsType<DeltaMergeBuilder>(
notMatchedActionBuilder.InsertExpr(new Dictionary<string, string>()));
Assert.IsType<DeltaMergeBuilder>(notMatchedActionBuilder.InsertAll());
// Update and UpdateExpr should return void.
table.Update(new Dictionary<string, Column>() { });
table.Update(Functions.Expr("id % 2 == 0"), new Dictionary<string, Column>() { });
table.UpdateExpr(new Dictionary<string, string>() { });
table.UpdateExpr("id % 2 == 1", new Dictionary<string, string>() { });
// Update and UpdateExpr should return void.
table.Update(new Dictionary<string, Column>() { });
table.Update(Functions.Expr("id % 2 == 0"), new Dictionary<string, Column>() { });
table.UpdateExpr(new Dictionary<string, string>() { });
table.UpdateExpr("id % 2 == 1", new Dictionary<string, string>() { });
Assert.IsType<DataFrame>(table.Vacuum());
Assert.IsType<DataFrame>(table.Vacuum(168));
Assert.IsType<DataFrame>(table.Vacuum());
Assert.IsType<DataFrame>(table.Vacuum(168));
// Delete should return void.
table.Delete("id > 10");
table.Delete(Functions.Expr("id > 5"));
table.Delete();
// Delete should return void.
table.Delete("id > 10");
table.Delete(Functions.Expr("id > 5"));
table.Delete();
// Load the table as a streaming source.
Assert.IsType<DataFrame>(_spark
.ReadStream()
.Format("delta")
.Option("path", path)
.Load());
Assert.IsType<DataFrame>(_spark.ReadStream().Format("delta").Load(path));
// Load the table as a streaming source.
Assert.IsType<DataFrame>(_spark
.ReadStream()
.Format("delta")
.Option("path", path)
.Load());
Assert.IsType<DataFrame>(_spark.ReadStream().Format("delta").Load(path));
// Create Parquet data and convert it to DeltaTables.
string parquetIdentifier = $"parquet.`{path}`";
rangeRate.Write().Mode(SaveMode.Overwrite).Parquet(path);
Assert.IsType<DeltaTable>(DeltaTable.ConvertToDelta(_spark, parquetIdentifier));
rangeRate
.Select(Functions.Col("id"), Functions.Expr($"(`id` + 1) AS `id_plus_one`"))
.Write()
.PartitionBy("id")
.Mode(SaveMode.Overwrite)
.Parquet(path);
Assert.IsType<DeltaTable>(DeltaTable.ConvertToDelta(
_spark,
parquetIdentifier,
"id bigint"));
// TODO: Test with StructType partition schema once StructType is supported.
}
// Create Parquet data and convert it to DeltaTables.
string parquetIdentifier = $"parquet.`{path}`";
rangeRate.Write().Mode(SaveMode.Overwrite).Parquet(path);
Assert.IsType<DeltaTable>(DeltaTable.ConvertToDelta(_spark, parquetIdentifier));
rangeRate
.Select(Functions.Col("id"), Functions.Expr($"(`id` + 1) AS `id_plus_one`"))
.Write()
.PartitionBy("id")
.Mode(SaveMode.Overwrite)
.Parquet(path);
Assert.IsType<DeltaTable>(DeltaTable.ConvertToDelta(
_spark,
parquetIdentifier,
"id bigint"));
// TODO: Test with StructType partition schema once StructType is supported.
}
/// <summary>

View file

@@ -44,59 +44,57 @@ namespace Microsoft.Spark.Worker.UnitTest
Commands = new[] { command }
};
using (var inputStream = new MemoryStream())
using (var outputStream = new MemoryStream())
using var inputStream = new MemoryStream();
using var outputStream = new MemoryStream();
int numRows = 10;
// Write test data to the input stream.
var pickler = new Pickler();
for (int i = 0; i < numRows; ++i)
{
int numRows = 10;
// Write test data to the input stream.
var pickler = new Pickler();
for (int i = 0; i < numRows; ++i)
{
var pickled = pickler.dumps(
new[] { new object[] { (i % 2 == 0) ? null : i.ToString() } });
SerDe.Write(inputStream, pickled.Length);
SerDe.Write(inputStream, pickled);
}
SerDe.Write(inputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(10, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
var unpickler = new Unpickler();
// One row was written as a batch above, so we need to read 'numRows' batches.
List<object> rows = new List<object>();
for (int i = 0; i < numRows; ++i)
{
int length = SerDe.ReadInt32(outputStream);
byte[] pickledBytes = SerDe.ReadBytes(outputStream, length);
rows.Add((unpickler.loads(pickledBytes) as ArrayList)[0] as object);
}
Assert.Equal(numRows, rows.Count);
// Validate the single command.
for (int i = 0; i < numRows; ++i)
{
Assert.Equal(
"udf: " + ((i % 2 == 0) ? "NULL" : i.ToString()),
(string)rows[i]);
}
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
var pickled = pickler.dumps(
new[] { new object[] { (i % 2 == 0) ? null : i.ToString() } });
SerDe.Write(inputStream, pickled.Length);
SerDe.Write(inputStream, pickled);
}
SerDe.Write(inputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(10, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
var unpickler = new Unpickler();
// One row was written as a batch above, so we need to read 'numRows' batches.
List<object> rows = new List<object>();
for (int i = 0; i < numRows; ++i)
{
int length = SerDe.ReadInt32(outputStream);
byte[] pickledBytes = SerDe.ReadBytes(outputStream, length);
rows.Add((unpickler.loads(pickledBytes) as ArrayList)[0] as object);
}
Assert.Equal(numRows, rows.Count);
// Validate the single command.
for (int i = 0; i < numRows; ++i)
{
Assert.Equal(
"udf: " + ((i % 2 == 0) ? "NULL" : i.ToString()),
(string)rows[i]);
}
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
}
[Fact]
@@ -130,59 +128,57 @@ namespace Microsoft.Spark.Worker.UnitTest
Commands = new[] { command1, command2 }
};
using (var inputStream = new MemoryStream())
using (var outputStream = new MemoryStream())
using var inputStream = new MemoryStream();
using var outputStream = new MemoryStream();
int numRows = 10;
// Write test data to the input stream.
var pickler = new Pickler();
for (int i = 0; i < numRows; ++i)
{
int numRows = 10;
// Write test data to the input stream.
var pickler = new Pickler();
for (int i = 0; i < numRows; ++i)
{
byte[] pickled = pickler.dumps(
new[] { new object[] { i.ToString(), i, i } });
SerDe.Write(inputStream, pickled.Length);
SerDe.Write(inputStream, pickled);
}
SerDe.Write(inputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(10, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
var unpickler = new Unpickler();
// One row was written as a batch above, so we need to read 'numRows' batches.
List<object[]> rows = new List<object[]>();
for (int i = 0; i < numRows; ++i)
{
int length = SerDe.ReadInt32(outputStream);
byte[] pickledBytes = SerDe.ReadBytes(outputStream, length);
rows.Add((unpickler.loads(pickledBytes) as ArrayList)[0] as object[]);
}
Assert.Equal(numRows, rows.Count);
for (int i = 0; i < numRows; ++i)
{
// There were two UDFs each of which produces one column.
object[] columns = rows[i];
Assert.Equal($"udf: {i}", (string)columns[0]);
Assert.Equal(i * i, (int)columns[1]);
}
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
byte[] pickled = pickler.dumps(
new[] { new object[] { i.ToString(), i, i } });
SerDe.Write(inputStream, pickled.Length);
SerDe.Write(inputStream, pickled);
}
SerDe.Write(inputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(10, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
var unpickler = new Unpickler();
// One row was written as a batch above, so we need to read 'numRows' batches.
List<object[]> rows = new List<object[]>();
for (int i = 0; i < numRows; ++i)
{
int length = SerDe.ReadInt32(outputStream);
byte[] pickledBytes = SerDe.ReadBytes(outputStream, length);
rows.Add((unpickler.loads(pickledBytes) as ArrayList)[0] as object[]);
}
Assert.Equal(numRows, rows.Count);
for (int i = 0; i < numRows; ++i)
{
// There were two UDFs each of which produces one column.
object[] columns = rows[i];
Assert.Equal($"udf: {i}", (string)columns[0]);
Assert.Equal(i * i, (int)columns[1]);
}
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
}
[Fact]
@@ -204,27 +200,25 @@ namespace Microsoft.Spark.Worker.UnitTest
Commands = new[] { command }
};
using (var inputStream = new MemoryStream())
using (var outputStream = new MemoryStream())
{
// Write test data to the input stream. For the empty input scenario,
// only send SpecialLengths.END_OF_DATA_SECTION.
SerDe.Write(inputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
inputStream.Seek(0, SeekOrigin.Begin);
using var inputStream = new MemoryStream();
using var outputStream = new MemoryStream();
// Write test data to the input stream. For the empty input scenario,
// only send SpecialLengths.END_OF_DATA_SECTION.
SerDe.Write(inputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(0, stat.NumEntriesProcessed);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(0, stat.NumEntriesProcessed);
// Validate the output stream.
Assert.Equal(0, outputStream.Length);
}
// Validate the output stream.
Assert.Equal(0, outputStream.Length);
}
[Fact]
@@ -251,62 +245,60 @@ namespace Microsoft.Spark.Worker.UnitTest
Commands = new[] { command }
};
using (var inputStream = new MemoryStream())
using (var outputStream = new MemoryStream())
using var inputStream = new MemoryStream();
using var outputStream = new MemoryStream();
int numRows = 10;
// Write test data to the input stream.
Schema schema = new Schema.Builder()
.Field(b => b.Name("arg1").DataType(StringType.Default))
.Build();
var arrowWriter = new ArrowStreamWriter(inputStream, schema);
await arrowWriter.WriteRecordBatchAsync(
new RecordBatch(
schema,
new[]
{
ToArrowArray(
Enumerable.Range(0, numRows)
.Select(i => i.ToString())
.ToArray())
},
numRows));
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(numRows, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
int arrowLength = SerDe.ReadInt32(outputStream);
Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
var arrowReader = new ArrowStreamReader(outputStream);
RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();
Assert.Equal(numRows, outputBatch.Length);
Assert.Single(outputBatch.Arrays);
var array = (StringArray)outputBatch.Arrays.ElementAt(0);
// Validate the single command.
for (int i = 0; i < numRows; ++i)
{
int numRows = 10;
// Write test data to the input stream.
Schema schema = new Schema.Builder()
.Field(b => b.Name("arg1").DataType(StringType.Default))
.Build();
var arrowWriter = new ArrowStreamWriter(inputStream, schema);
await arrowWriter.WriteRecordBatchAsync(
new RecordBatch(
schema,
new[]
{
ToArrowArray(
Enumerable.Range(0, numRows)
.Select(i => i.ToString())
.ToArray())
},
numRows));
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(numRows, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
int arrowLength = SerDe.ReadInt32(outputStream);
Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
var arrowReader = new ArrowStreamReader(outputStream);
RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();
Assert.Equal(numRows, outputBatch.Length);
Assert.Single(outputBatch.Arrays);
var array = (StringArray)outputBatch.Arrays.ElementAt(0);
// Validate the single command.
for (int i = 0; i < numRows; ++i)
{
Assert.Equal($"udf: {i}", array.GetString(i));
}
int end = SerDe.ReadInt32(outputStream);
Assert.Equal(0, end);
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
Assert.Equal($"udf: {i}", array.GetString(i));
}
int end = SerDe.ReadInt32(outputStream);
Assert.Equal(0, end);
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
}
[Fact]
@@ -347,67 +339,65 @@ namespace Microsoft.Spark.Worker.UnitTest
Commands = new[] { command1, command2 }
};
using (var inputStream = new MemoryStream())
using (var outputStream = new MemoryStream())
using var inputStream = new MemoryStream();
using var outputStream = new MemoryStream();
int numRows = 10;
// Write test data to the input stream.
Schema schema = new Schema.Builder()
.Field(b => b.Name("arg1").DataType(StringType.Default))
.Field(b => b.Name("arg2").DataType(Int32Type.Default))
.Field(b => b.Name("arg3").DataType(Int32Type.Default))
.Build();
var arrowWriter = new ArrowStreamWriter(inputStream, schema);
await arrowWriter.WriteRecordBatchAsync(
new RecordBatch(
schema,
new[]
{
ToArrowArray(
Enumerable.Range(0, numRows)
.Select(i => i.ToString())
.ToArray()),
ToArrowArray(Enumerable.Range(0, numRows).ToArray()),
ToArrowArray(Enumerable.Range(0, numRows).ToArray()),
},
numRows));
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(numRows, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
var arrowLength = SerDe.ReadInt32(outputStream);
Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
var arrowReader = new ArrowStreamReader(outputStream);
RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();
Assert.Equal(numRows, outputBatch.Length);
Assert.Equal(2, outputBatch.Arrays.Count());
var array1 = (StringArray)outputBatch.Arrays.ElementAt(0);
var array2 = (Int32Array)outputBatch.Arrays.ElementAt(1);
for (int i = 0; i < numRows; ++i)
{
int numRows = 10;
// Write test data to the input stream.
Schema schema = new Schema.Builder()
.Field(b => b.Name("arg1").DataType(StringType.Default))
.Field(b => b.Name("arg2").DataType(Int32Type.Default))
.Field(b => b.Name("arg3").DataType(Int32Type.Default))
.Build();
var arrowWriter = new ArrowStreamWriter(inputStream, schema);
await arrowWriter.WriteRecordBatchAsync(
new RecordBatch(
schema,
new[]
{
ToArrowArray(
Enumerable.Range(0, numRows)
.Select(i => i.ToString())
.ToArray()),
ToArrowArray(Enumerable.Range(0, numRows).ToArray()),
ToArrowArray(Enumerable.Range(0, numRows).ToArray()),
},
numRows));
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(numRows, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
var arrowLength = SerDe.ReadInt32(outputStream);
Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
var arrowReader = new ArrowStreamReader(outputStream);
RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();
Assert.Equal(numRows, outputBatch.Length);
Assert.Equal(2, outputBatch.Arrays.Count());
var array1 = (StringArray)outputBatch.Arrays.ElementAt(0);
var array2 = (Int32Array)outputBatch.Arrays.ElementAt(1);
for (int i = 0; i < numRows; ++i)
{
Assert.Equal($"udf: {i}", array1.GetString(i));
Assert.Equal(i * i, array2.Values[i]);
}
int end = SerDe.ReadInt32(outputStream);
Assert.Equal(0, end);
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
Assert.Equal($"udf: {i}", array1.GetString(i));
Assert.Equal(i * i, array2.Values[i]);
}
int end = SerDe.ReadInt32(outputStream);
Assert.Equal(0, end);
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
}
/// <summary>
@@ -439,62 +429,60 @@ namespace Microsoft.Spark.Worker.UnitTest
Commands = new[] { command }
};
using (var inputStream = new MemoryStream())
using (var outputStream = new MemoryStream())
{
// Write test data to the input stream.
Schema schema = new Schema.Builder()
.Field(b => b.Name("arg1").DataType(StringType.Default))
.Build();
var arrowWriter = new ArrowStreamWriter(inputStream, schema);
using var inputStream = new MemoryStream();
using var outputStream = new MemoryStream();
// Write test data to the input stream.
Schema schema = new Schema.Builder()
.Field(b => b.Name("arg1").DataType(StringType.Default))
.Build();
var arrowWriter = new ArrowStreamWriter(inputStream, schema);
// The .NET ArrowStreamWriter doesn't currently support writing just a
// schema with no batches - but Java does. We use Reflection to simulate
// the request Spark sends.
MethodInfo writeSchemaMethod = arrowWriter.GetType().GetMethod(
"WriteSchemaAsync",
BindingFlags.NonPublic | BindingFlags.Instance);
// The .NET ArrowStreamWriter doesn't currently support writing just a
// schema with no batches - but Java does. We use Reflection to simulate
// the request Spark sends.
MethodInfo writeSchemaMethod = arrowWriter.GetType().GetMethod(
"WriteSchemaAsync",
BindingFlags.NonPublic | BindingFlags.Instance);
writeSchemaMethod.Invoke(
arrowWriter,
new object[] { schema, CancellationToken.None });
writeSchemaMethod.Invoke(
arrowWriter,
new object[] { schema, CancellationToken.None });
SerDe.Write(inputStream, 0);
SerDe.Write(inputStream, 0);
inputStream.Seek(0, SeekOrigin.Begin);
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(0, stat.NumEntriesProcessed);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(0, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
int arrowLength = SerDe.ReadInt32(outputStream);
Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
var arrowReader = new ArrowStreamReader(outputStream);
RecordBatch outputBatch = arrowReader.ReadNextRecordBatch();
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
int arrowLength = SerDe.ReadInt32(outputStream);
Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
var arrowReader = new ArrowStreamReader(outputStream);
RecordBatch outputBatch = arrowReader.ReadNextRecordBatch();
Assert.Equal(1, outputBatch.Schema.Fields.Count);
Assert.IsType<StringType>(outputBatch.Schema.GetFieldByIndex(0).DataType);
Assert.Equal(1, outputBatch.Schema.Fields.Count);
Assert.IsType<StringType>(outputBatch.Schema.GetFieldByIndex(0).DataType);
Assert.Equal(0, outputBatch.Length);
Assert.Single(outputBatch.Arrays);
Assert.Equal(0, outputBatch.Length);
Assert.Single(outputBatch.Arrays);
var array = (StringArray)outputBatch.Arrays.ElementAt(0);
Assert.Equal(0, array.Length);
var array = (StringArray)outputBatch.Arrays.ElementAt(0);
Assert.Equal(0, array.Length);
int end = SerDe.ReadInt32(outputStream);
Assert.Equal(0, end);
int end = SerDe.ReadInt32(outputStream);
Assert.Equal(0, end);
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
}
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
}
[Fact]
@@ -546,79 +534,77 @@ namespace Microsoft.Spark.Worker.UnitTest
Commands = new[] { command }
};
using (var inputStream = new MemoryStream())
using (var outputStream = new MemoryStream())
using var inputStream = new MemoryStream();
using var outputStream = new MemoryStream();
int numRows = 10;
// Write test data to the input stream.
Schema schema = new Schema.Builder()
.Field(b => b.Name("arg1").DataType(StringType.Default))
.Field(b => b.Name("arg2").DataType(Int64Type.Default))
.Build();
var arrowWriter = new ArrowStreamWriter(inputStream, schema);
await arrowWriter.WriteRecordBatchAsync(
new RecordBatch(
schema,
new[]
{
ToArrowArray(
Enumerable.Range(0, numRows)
.Select(i => i.ToString())
.ToArray()),
ToArrowArray(
Enumerable.Range(0, numRows)
.Select(i => (long)i)
.ToArray())
},
numRows));
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(numRows, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
int arrowLength = SerDe.ReadInt32(outputStream);
Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
var arrowReader = new ArrowStreamReader(outputStream);
RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();
Assert.Equal(numRows, outputBatch.Length);
Assert.Equal(2, outputBatch.ColumnCount);
var stringArray = (StringArray)outputBatch.Column(0);
for (int i = 0; i < numRows; ++i)
{
int numRows = 10;
// Write test data to the input stream.
Schema schema = new Schema.Builder()
.Field(b => b.Name("arg1").DataType(StringType.Default))
.Field(b => b.Name("arg2").DataType(Int64Type.Default))
.Build();
var arrowWriter = new ArrowStreamWriter(inputStream, schema);
await arrowWriter.WriteRecordBatchAsync(
new RecordBatch(
schema,
new[]
{
ToArrowArray(
Enumerable.Range(0, numRows)
.Select(i => i.ToString())
.ToArray()),
ToArrowArray(
Enumerable.Range(0, numRows)
.Select(i => (long)i)
.ToArray())
},
numRows));
inputStream.Seek(0, SeekOrigin.Begin);
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate that all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(numRows, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
int arrowLength = SerDe.ReadInt32(outputStream);
Assert.Equal((int)SpecialLengths.START_ARROW_STREAM, arrowLength);
var arrowReader = new ArrowStreamReader(outputStream);
RecordBatch outputBatch = await arrowReader.ReadNextRecordBatchAsync();
Assert.Equal(numRows, outputBatch.Length);
Assert.Equal(2, outputBatch.ColumnCount);
var stringArray = (StringArray)outputBatch.Column(0);
for (int i = 0; i < numRows; ++i)
{
Assert.Equal($"udf: {i}", stringArray.GetString(i));
}
var longArray = (Int64Array)outputBatch.Column(1);
for (int i = 0; i < numRows; ++i)
{
Assert.Equal(100 + i, longArray.Values[i]);
}
int end = SerDe.ReadInt32(outputStream);
Assert.Equal(0, end);
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
Assert.Equal($"udf: {i}", stringArray.GetString(i));
}
var longArray = (Int64Array)outputBatch.Column(1);
for (int i = 0; i < numRows; ++i)
{
Assert.Equal(100 + i, longArray.Values[i]);
}
int end = SerDe.ReadInt32(outputStream);
Assert.Equal(0, end);
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
}
[Fact]
public void TestRDDCommandExecutor()
{
int mapUdf(int a) => a + 3;
static int mapUdf(int a) => a + 3;
var command = new RDDCommand()
{
WorkerFunction = new RDD.WorkerFunction(
@@ -633,57 +619,55 @@ namespace Microsoft.Spark.Worker.UnitTest
Commands = new[] { command }
};
using (var inputStream = new MemoryStream())
using (var outputStream = new MemoryStream())
using var inputStream = new MemoryStream();
using var outputStream = new MemoryStream();
// Write test data to the input stream.
var formatter = new BinaryFormatter();
var memoryStream = new MemoryStream();
var inputs = new[] { 0, 1, 2, 3, 4 };
var values = new List<byte[]>();
foreach (int input in inputs)
{
// Write test data to the input stream.
var formatter = new BinaryFormatter();
var memoryStream = new MemoryStream();
var inputs = new[] { 0, 1, 2, 3, 4 };
var values = new List<byte[]>();
foreach (int input in inputs)
{
memoryStream.Position = 0;
formatter.Serialize(memoryStream, input);
values.Add(memoryStream.ToArray());
}
foreach (byte[] value in values)
{
SerDe.Write(inputStream, value.Length);
SerDe.Write(inputStream, value);
}
SerDe.Write(inputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
inputStream.Seek(0, SeekOrigin.Begin);
// Execute the command.
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(5, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
for (int i = 0; i < inputs.Length; ++i)
{
Assert.True(SerDe.ReadInt32(outputStream) > 0);
Assert.Equal(
mapUdf(i),
formatter.Deserialize(outputStream));
}
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
memoryStream.Position = 0;
formatter.Serialize(memoryStream, input);
values.Add(memoryStream.ToArray());
}
foreach (byte[] value in values)
{
SerDe.Write(inputStream, value.Length);
SerDe.Write(inputStream, value);
}
SerDe.Write(inputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
inputStream.Seek(0, SeekOrigin.Begin);
// Execute the command.
CommandExecutorStat stat = new CommandExecutor().Execute(
inputStream,
outputStream,
0,
commandPayload);
// Validate all the data on the stream is read.
Assert.Equal(inputStream.Length, inputStream.Position);
Assert.Equal(5, stat.NumEntriesProcessed);
// Validate the output stream.
outputStream.Seek(0, SeekOrigin.Begin);
for (int i = 0; i < inputs.Length; ++i)
{
Assert.True(SerDe.ReadInt32(outputStream) > 0);
Assert.Equal(
mapUdf(i),
formatter.Deserialize(outputStream));
}
// Validate all the data on the stream is read.
Assert.Equal(outputStream.Length, outputStream.Position);
}
}
}
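One recurring, behavior-preserving pattern in the test changes above is the move from C# using statements to C# 8 using declarations: the disposable is now disposed at the end of the enclosing scope rather than at a dedicated closing brace, which removes one level of nesting. A minimal illustrative sketch of the two forms (the MemoryStream is just a convenient disposable, not tied to this repository):

using System.IO;

internal static class UsingDeclarationSketch
{
    public static void UsingStatement()
    {
        using (var stream = new MemoryStream())
        {
            stream.WriteByte(1);
        } // stream is disposed here, at the closing brace of the using block.
    }

    public static void UsingDeclaration()
    {
        using var stream = new MemoryStream();
        stream.WriteByte(1);
        // stream is disposed when the enclosing method scope ends.
    }
}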

View file

@@ -22,6 +22,7 @@ namespace Microsoft.Spark.Worker.UnitTest
[InlineData(Versions.V2_3_2)]
[InlineData(Versions.V2_3_3)]
[InlineData(Versions.V2_4_0)]
[InlineData(Versions.V3_0_0)]
public void TestPayloadProcessor(string version)
{
CommandPayload commandPayload = TestData.GetDefaultCommandPayload();
@@ -33,11 +34,8 @@ namespace Microsoft.Spark.Worker.UnitTest
{
payloadWriter.Write(outStream, payload, commandPayload);
using (var inputStream = new MemoryStream(outStream.ToArray()))
{
actualPayload =
new PayloadProcessor(payloadWriter.Version).Process(inputStream);
}
using var inputStream = new MemoryStream(outStream.ToArray());
actualPayload = new PayloadProcessor(payloadWriter.Version).Process(inputStream);
}
// Validate the read payload.
@@ -77,11 +75,11 @@ namespace Microsoft.Spark.Worker.UnitTest
PayloadWriter payloadWriter = new PayloadWriterFactory().Create();
Payload payload = TestData.GetDefaultPayload();
var serverListener = new DefaultSocketWrapper();
using var serverListener = new DefaultSocketWrapper();
serverListener.Listen();
var port = (serverListener.LocalEndPoint as IPEndPoint).Port;
var clientSocket = new DefaultSocketWrapper();
using var clientSocket = new DefaultSocketWrapper();
clientSocket.Connect(IPAddress.Loopback, port, null);
using (ISocketWrapper serverSocket = serverListener.Accept())

View file

@@ -88,6 +88,43 @@ namespace Microsoft.Spark.Worker.UnitTest
}
}
/// <summary>
/// TaskContextWriter for version 3.0.*.
/// </summary>
internal sealed class TaskContextWriterV3_0_X : ITaskContextWriter
{
public void Write(Stream stream, TaskContext taskContext)
{
SerDe.Write(stream, taskContext.IsBarrier);
SerDe.Write(stream, taskContext.Port);
SerDe.Write(stream, taskContext.Secret);
SerDe.Write(stream, taskContext.StageId);
SerDe.Write(stream, taskContext.PartitionId);
SerDe.Write(stream, taskContext.AttemptNumber);
SerDe.Write(stream, taskContext.AttemptId);
SerDe.Write(stream, taskContext.Resources.Count());
foreach (TaskContext.Resource resource in taskContext.Resources)
{
SerDe.Write(stream, resource.Key);
SerDe.Write(stream, resource.Value);
SerDe.Write(stream, resource.Addresses.Count());
foreach (string address in resource.Addresses)
{
SerDe.Write(stream, address);
}
}
SerDe.Write(stream, taskContext.LocalProperties.Count);
foreach (KeyValuePair<string, string> kv in taskContext.LocalProperties)
{
SerDe.Write(stream, kv.Key);
SerDe.Write(stream, kv.Value);
}
}
}
///////////////////////////////////////////////////////////////////////////
// BroadcastVariable writer for different Spark versions.
///////////////////////////////////////////////////////////////////////////
@@ -302,6 +339,12 @@ namespace Microsoft.Spark.Worker.UnitTest
new TaskContextWriterV2_4_X(),
new BroadcastVariableWriterV2_3_2(),
new CommandWriterV2_4_X());
case Versions.V3_0_0:
return new PayloadWriter(
version,
new TaskContextWriterV3_0_X(),
new BroadcastVariableWriterV2_3_2(),
new CommandWriterV2_4_X());
default:
throw new NotSupportedException($"Spark {version} is not supported.");
}

View file

@@ -18,102 +18,100 @@ namespace Microsoft.Spark.Worker.UnitTest
[Fact]
public void TestTaskRunner()
{
using (var serverListener = new DefaultSocketWrapper())
using var serverListener = new DefaultSocketWrapper();
serverListener.Listen();
var port = (serverListener.LocalEndPoint as IPEndPoint).Port;
var clientSocket = new DefaultSocketWrapper();
clientSocket.Connect(IPAddress.Loopback, port, null);
PayloadWriter payloadWriter = new PayloadWriterFactory().Create();
var taskRunner = new TaskRunner(0, clientSocket, false, payloadWriter.Version);
var clientTask = Task.Run(() => taskRunner.Run());
using (ISocketWrapper serverSocket = serverListener.Accept())
{
serverListener.Listen();
System.IO.Stream inputStream = serverSocket.InputStream;
System.IO.Stream outputStream = serverSocket.OutputStream;
var port = (serverListener.LocalEndPoint as IPEndPoint).Port;
var clientSocket = new DefaultSocketWrapper();
clientSocket.Connect(IPAddress.Loopback, port, null);
Payload payload = TestData.GetDefaultPayload();
CommandPayload commandPayload = TestData.GetDefaultCommandPayload();
PayloadWriter payloadWriter = new PayloadWriterFactory().Create();
var taskRunner = new TaskRunner(0, clientSocket, false, payloadWriter.Version);
var clientTask = Task.Run(() => taskRunner.Run());
payloadWriter.Write(outputStream, payload, commandPayload);
using (ISocketWrapper serverSocket = serverListener.Accept())
// Write 10 rows to the output stream.
var pickler = new Pickler();
for (int i = 0; i < 10; ++i)
{
System.IO.Stream inputStream = serverSocket.InputStream;
System.IO.Stream outputStream = serverSocket.OutputStream;
var pickled = pickler.dumps(
new[] { new object[] { i.ToString(), i, i } });
SerDe.Write(outputStream, pickled.Length);
SerDe.Write(outputStream, pickled);
}
Payload payload = TestData.GetDefaultPayload();
CommandPayload commandPayload = TestData.GetDefaultCommandPayload();
// Signal the end of data and stream.
SerDe.Write(outputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
SerDe.Write(outputStream, (int)SpecialLengths.END_OF_STREAM);
outputStream.Flush();
payloadWriter.Write(outputStream, payload, commandPayload);
// Now process the bytes flowing in from the client.
var timingDataReceived = false;
var exceptionThrown = false;
var rowsReceived = new List<object[]>();
// Write 10 rows to the output stream.
var pickler = new Pickler();
for (int i = 0; i < 10; ++i)
while (true)
{
var length = SerDe.ReadInt32(inputStream);
if (length > 0)
{
var pickled = pickler.dumps(
new[] { new object[] { i.ToString(), i, i } });
SerDe.Write(outputStream, pickled.Length);
SerDe.Write(outputStream, pickled);
}
// Signal the end of data and stream.
SerDe.Write(outputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
SerDe.Write(outputStream, (int)SpecialLengths.END_OF_STREAM);
outputStream.Flush();
// Now process the bytes flowing in from the client.
var timingDataReceived = false;
var exceptionThrown = false;
var rowsReceived = new List<object[]>();
while (true)
{
var length = SerDe.ReadInt32(inputStream);
if (length > 0)
var pickledBytes = SerDe.ReadBytes(inputStream, length);
using var unpickler = new Unpickler();
var rows = unpickler.loads(pickledBytes) as ArrayList;
foreach (object row in rows)
{
var pickledBytes = SerDe.ReadBytes(inputStream, length);
var unpickler = new Unpickler();
var rows = unpickler.loads(pickledBytes) as ArrayList;
foreach (object row in rows)
{
rowsReceived.Add((object[])row);
}
}
else if (length == (int)SpecialLengths.TIMING_DATA)
{
var bootTime = SerDe.ReadInt64(inputStream);
var initTime = SerDe.ReadInt64(inputStream);
var finishTime = SerDe.ReadInt64(inputStream);
var memoryBytesSpilled = SerDe.ReadInt64(inputStream);
var diskBytesSpilled = SerDe.ReadInt64(inputStream);
timingDataReceived = true;
}
else if (length == (int)SpecialLengths.PYTHON_EXCEPTION_THROWN)
{
SerDe.ReadString(inputStream);
exceptionThrown = true;
break;
}
else if (length == (int)SpecialLengths.END_OF_DATA_SECTION)
{
var numAccumulatorUpdates = SerDe.ReadInt32(inputStream);
SerDe.ReadInt32(inputStream);
break;
rowsReceived.Add((object[])row);
}
}
Assert.True(timingDataReceived);
Assert.False(exceptionThrown);
// Validate rows received.
Assert.Equal(10, rowsReceived.Count);
for (int i = 0; i < 10; ++i)
else if (length == (int)SpecialLengths.TIMING_DATA)
{
// Two UDFs registered, thus expecting two columns.
// Refer to TestData.GetDefaultCommandPayload().
var row = rowsReceived[i];
Assert.Equal(2, rowsReceived[i].Length);
Assert.Equal($"udf2 udf1 {i}", row[0]);
Assert.Equal(i + i, row[1]);
var bootTime = SerDe.ReadInt64(inputStream);
var initTime = SerDe.ReadInt64(inputStream);
var finishTime = SerDe.ReadInt64(inputStream);
var memoryBytesSpilled = SerDe.ReadInt64(inputStream);
var diskBytesSpilled = SerDe.ReadInt64(inputStream);
timingDataReceived = true;
}
else if (length == (int)SpecialLengths.PYTHON_EXCEPTION_THROWN)
{
SerDe.ReadString(inputStream);
exceptionThrown = true;
break;
}
else if (length == (int)SpecialLengths.END_OF_DATA_SECTION)
{
var numAccumulatorUpdates = SerDe.ReadInt32(inputStream);
SerDe.ReadInt32(inputStream);
break;
}
}
Assert.True(clientTask.Wait(5000));
Assert.True(timingDataReceived);
Assert.False(exceptionThrown);
// Validate rows received.
Assert.Equal(10, rowsReceived.Count);
for (int i = 0; i < 10; ++i)
{
// Two UDFs registered, thus expecting two columns.
// Refer to TestData.GetDefaultCommandPayload().
var row = rowsReceived[i];
Assert.Equal(2, rowsReceived[i].Length);
Assert.Equal($"udf2 udf1 {i}", row[0]);
Assert.Equal(i + i, row[1]);
}
}
Assert.True(clientTask.Wait(5000));
}
}
}

View file

@@ -27,13 +27,14 @@ namespace Microsoft.Spark.Worker
internal string Secret { get; set; }
internal IEnumerable<Resource> Resources { get; set; } = new List<Resource>();
internal Dictionary<string, string> LocalProperties { get; set; } =
new Dictionary<string, string>();
public override bool Equals(object obj)
{
var other = obj as TaskContext;
if (other is null)
if (!(obj is TaskContext other))
{
return false;
}
@@ -42,6 +43,7 @@ namespace Microsoft.Spark.Worker
(PartitionId == other.PartitionId) &&
(AttemptNumber == other.AttemptNumber) &&
(AttemptId == other.AttemptId) &&
Resources.SequenceEqual(other.Resources) &&
(LocalProperties.Count == other.LocalProperties.Count) &&
!LocalProperties.Except(other.LocalProperties).Any();
}
@@ -50,6 +52,30 @@ namespace Microsoft.Spark.Worker
{
return StageId;
}
internal class Resource
{
internal string Key { get; set; }
internal string Value { get; set; }
internal IEnumerable<string> Addresses { get; set; } = new List<string>();
public override bool Equals(object obj)
{
if (!(obj is Resource other))
{
return false;
}
return (Key == other.Key) &&
(Value == other.Value) &&
Addresses.SequenceEqual(other.Addresses);
}
public override int GetHashCode()
{
return Key.GetHashCode();
}
}
}
/// <summary>
@@ -68,8 +94,7 @@ namespace Microsoft.Spark.Worker
public override bool Equals(object obj)
{
var other = obj as BroadcastVariables;
if (other is null)
if (!(obj is BroadcastVariables other))
{
return false;
}

View file

@@ -114,18 +114,13 @@ namespace Microsoft.Spark.Worker.Processor
throw new NotImplementedException($"{evalType} is not supported.");
}
if (version.Major == 2)
return (version.Major, version.Minor) switch
{
switch (version.Minor)
{
case 3:
return SqlCommandProcessorV2_3_X.Process(evalType, stream);
case 4:
return SqlCommandProcessorV2_4_X.Process(evalType, stream);
}
}
throw new NotSupportedException($"Spark {version} not supported.");
(2, 3) => SqlCommandProcessorV2_3_X.Process(evalType, stream),
(2, 4) => SqlCommandProcessorV2_4_X.Process(evalType, stream),
(3, 0) => SqlCommandProcessorV2_4_X.Process(evalType, stream),
_ => throw new NotSupportedException($"Spark {version} not supported.")
};
}
/// <summary>

View file

@@ -19,31 +19,66 @@ namespace Microsoft.Spark.Worker.Processor
internal TaskContext Process(Stream stream)
{
if (_version.Major == 2)
return (_version.Major, _version.Minor) switch
{
switch (_version.Minor)
(2, 3) => TaskContextProcessorV2_3_X.Process(stream),
(2, 4) => TaskContextProcessorV2_4_X.Process(stream),
(3, 0) => TaskContextProcessorV3_0_X.Process(stream),
_ => throw new NotSupportedException($"Spark {_version} not supported.")
};
}
private static TaskContext ReadTaskContext(Stream stream)
{
return new TaskContext
{
StageId = SerDe.ReadInt32(stream),
PartitionId = SerDe.ReadInt32(stream),
AttemptNumber = SerDe.ReadInt32(stream),
AttemptId = SerDe.ReadInt64(stream)
};
}
private static void ReadBarrierInfo(Stream stream)
{
// Read barrier-related payload. Note that barrier is currently not supported.
SerDe.ReadBool(stream); // IsBarrier
SerDe.ReadInt32(stream); // BoundPort
SerDe.ReadString(stream); // Secret
}
private static void ReadTaskContextProperties(Stream stream, TaskContext taskContext)
{
int numProperties = SerDe.ReadInt32(stream);
for (int i = 0; i < numProperties; ++i)
{
string key = SerDe.ReadString(stream);
string value = SerDe.ReadString(stream);
taskContext.LocalProperties.Add(key, value);
}
}
private static void ReadTaskContextResources(Stream stream)
{
// Currently, resources are not supported.
int numResources = SerDe.ReadInt32(stream);
for (int i = 0; i < numResources; ++i)
{
SerDe.ReadString(stream); // key
SerDe.ReadString(stream); // value
int numAddresses = SerDe.ReadInt32(stream);
for (int j = 0; j < numAddresses; ++j)
{
case 3:
return TaskContextProcessorV2_3_X.Process(stream);
case 4:
return TaskContextProcessorV2_4_X.Process(stream);
SerDe.ReadString(stream); // address
}
}
throw new NotSupportedException($"Spark {_version} not supported.");
}
private static class TaskContextProcessorV2_3_X
{
internal static TaskContext Process(Stream stream)
{
return new TaskContext
{
StageId = SerDe.ReadInt32(stream),
PartitionId = SerDe.ReadInt32(stream),
AttemptNumber = SerDe.ReadInt32(stream),
AttemptId = SerDe.ReadInt64(stream)
};
return ReadTaskContext(stream);
}
}
@@ -51,26 +86,22 @@ namespace Microsoft.Spark.Worker.Processor
{
internal static TaskContext Process(Stream stream)
{
// Read barrier-related payload. Note that barrier is currently not supported.
SerDe.ReadBool(stream); // IsBarrier
SerDe.ReadInt32(stream); // BoundPort
SerDe.ReadString(stream); // Secret
ReadBarrierInfo(stream);
TaskContext taskContext = ReadTaskContext(stream);
ReadTaskContextProperties(stream, taskContext);
var taskContext = new TaskContext
{
StageId = SerDe.ReadInt32(stream),
PartitionId = SerDe.ReadInt32(stream),
AttemptNumber = SerDe.ReadInt32(stream),
AttemptId = SerDe.ReadInt64(stream)
};
return taskContext;
}
}
int numProperties = SerDe.ReadInt32(stream);
for (int i = 0; i < numProperties; ++i)
{
string key = SerDe.ReadString(stream);
string value = SerDe.ReadString(stream);
taskContext.LocalProperties.Add(key, value);
}
private static class TaskContextProcessorV3_0_X
{
internal static TaskContext Process(Stream stream)
{
ReadBarrierInfo(stream);
TaskContext taskContext = ReadTaskContext(stream);
ReadTaskContextResources(stream);
ReadTaskContextProperties(stream, taskContext);
return taskContext;
}

View file

@@ -877,7 +877,7 @@ namespace Microsoft.Spark.Sql
return true;
}
return obj is Column other && this._jvmObject.Equals(other._jvmObject);
return obj is Column other && _jvmObject.Equals(other._jvmObject);
}
/// <summary>

View file

@@ -12,6 +12,7 @@ namespace Microsoft.Spark
internal const string V2_3_3 = "2.3.3";
internal const string V2_4_0 = "2.4.0";
internal const string V2_4_2 = "2.4.2";
internal const string V3_0_0 = "3.0.0";
// The following is used to check the compatibility of UDFs between
// the driver and worker side. This needs to be updated only when there