Follow up on UDF that takes in and returns Row (#406)

This commit is contained in:
elvaliuliuliu 2020-01-29 09:06:55 -08:00 коммит произвёл GitHub
Родитель 96d0fedb13
Коммит a8db9853f8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 101 добавлений и 52 удалений

Просмотреть файл

@@ -1,3 +1,3 @@
{"name":"Michael", "ids":[1], "info":{"city":"Burdwan", "state":"Paschimbanga"}}
{"name":"Andy", "age":30, "ids":[3,5], "info":{"city":"Los Angeles", "state":"California"}}
{"name":"Justin", "age":19, "ids":[2,4], "info":{"city":"Seattle"}}
{"name":"Michael", "ids":[1], "info1":{"city":"Burdwan"}, "info2":{"state":"Paschimbanga"}, "info3":{"company":{"job":"Developer"}}}
{"name":"Andy", "age":30, "ids":[3,5], "info1":{"city":"Los Angeles"}, "info2":{"state":"California"}, "info3":{"company":{"job":"Developer"}}}
{"name":"Justin", "age":19, "ids":[2,4], "info1":{"city":"Seattle"}, "info2":{"state":"Washington"}, "info3":{"company":{"job":"Developer"}}}

Просмотреть файл

@@ -136,20 +136,59 @@ namespace Microsoft.Spark.E2ETest.UdfTests
[Fact]
public void TestUdfWithRowType()
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
{
string city = row.GetAs<string>("city");
string state = row.GetAs<string>("state");
return $"{city},{state}";
});
// Single Row
{
Func<Column, Column> udf = Udf<Row, string>(
(row) => row.GetAs<string>("city"));
Row[] rows = _df.Select(udf(_df["info"])).Collect().ToArray();
Assert.Equal(3, rows.Length);
Row[] rows = _df.Select(udf(_df["info1"])).Collect().ToArray();
Assert.Equal(3, rows.Length);
var expected = new[] { "Burdwan,Paschimbanga", "Los Angeles,California", "Seattle," };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
var expected = new[] { "Burdwan", "Los Angeles", "Seattle" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}
// Multiple Rows
{
Func<Column, Column, Column, Column> udf = Udf<Row, Row, string, string>(
(row1, row2, str) =>
{
string city = row1.GetAs<string>("city");
string state = row2.GetAs<string>("state");
return $"{str}:{city},{state}";
});
Row[] rows = _df
.Select(udf(_df["info1"], _df["info2"], _df["name"]))
.Collect()
.ToArray();
Assert.Equal(3, rows.Length);
var expected = new[] {
"Michael:Burdwan,Paschimbanga",
"Andy:Los Angeles,California",
"Justin:Seattle,Washington" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}
// Nested Row
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
{
Row outerCol = row.GetAs<Row>("company");
return outerCol.GetAs<string>("job");
});
Row[] rows = _df.Select(udf(_df["info3"])).Collect().ToArray();
Assert.Equal(3, rows.Length);
var expected = new[] { "Developer", "Developer", "Developer" };
string[] actual = rows.Select(x => x[0].ToString()).ToArray();
Assert.Equal(expected, actual);
}
}
/// <summary>
@@ -168,14 +207,40 @@ namespace Microsoft.Spark.E2ETest.UdfTests
Func<Column, Column> udf = Udf<string>(
str => new GenericRow(new object[] { 1, "abc" }), schema);
Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
Row[] rows = _df.Select(udf(_df["name"]).As("col")).Collect().ToArray();
Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Equal(1, row.Size());
Row outerCol = row.GetAs<Row>("col");
Assert.Equal(2, outerCol.Size());
Assert.Equal(1, outerCol.GetAs<int>("col1"));
Assert.Equal("abc", outerCol.GetAs<string>("col2"));
}
}
// Generic row is a part of top-level column.
{
var schema = new StructType(new[]
{
new StructField("col1", new IntegerType())
});
Func<Column, Column> udf = Udf<string>(
str => new GenericRow(new object[] { 111 }), schema);
Column nameCol = _df["name"];
Row[] rows = _df.Select(udf(nameCol).As("col"), nameCol).Collect().ToArray();
Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Equal(2, row.Size());
Assert.Equal(1, row.GetAs<int>("col1"));
Assert.Equal("abc", row.GetAs<string>("col2"));
Row col1 = row.GetAs<Row>("col");
Assert.Equal(1, col1.Size());
Assert.Equal(111, col1.GetAs<int>("col1"));
string col2 = row.GetAs<string>("name");
Assert.NotEmpty(col2);
}
}
@@ -211,21 +276,23 @@ namespace Microsoft.Spark.E2ETest.UdfTests
}),
schema);
Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
Row[] rows = _df.Select(udf(_df["name"]).As("col")).Collect().ToArray();
Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Equal(3, row.Size());
Assert.Equal(1, row.GetAs<int>("col1"));
Assert.Equal(1, row.Size());
Row outerCol = row.GetAs<Row>("col");
Assert.Equal(3, outerCol.Size());
Assert.Equal(1, outerCol.GetAs<int>("col1"));
Assert.Equal(
new Row(new object[] { 1 }, subSchema1),
row.GetAs<Row>("col2"));
outerCol.GetAs<Row>("col2"));
Assert.Equal(
new Row(
new object[] { "abc", new Row(new object[] { 10 }, subSchema1) },
subSchema2),
row.GetAs<Row>("col3"));
outerCol.GetAs<Row>("col3"));
}
}
}

Просмотреть файл

@@ -135,8 +135,17 @@ namespace Microsoft.Spark.Worker.Command
for (int i = 0; i < inputRows.Length; ++i)
{
object row = inputRows[i];
// The following can happen if an UDF takes Row object(s).
// The JVM Spark side sends a Row object that wraps all the columns used
// in the UDF, thus, it is normalized below (the extra layer is removed).
if (row is RowConstructor rowConstructor)
{
row = rowConstructor.GetRow().Values;
}
// Split id is not used for SQL UDFs, so 0 is passed.
outputRows.Add(commandRunner.Run(0, inputRows[i]));
outputRows.Add(commandRunner.Run(0, row));
}
// The initial (estimated) buffer size for pickling rows is set to the size of

Просмотреть файл

@@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.Spark.Interop.Ipc;
@@ -34,24 +33,7 @@ namespace Microsoft.Spark.Sql
foreach (object unpickled in unpickledObjects)
{
// Unpickled object can be either a RowConstructor object (not materialized),
// or a Row object (materialized). Refer to RowConstruct.construct() to see how
// Row objects are unpickled.
switch (unpickled)
{
case RowConstructor rc:
yield return rc.GetRow();
break;
case object[] objs when objs.Length == 1 && (objs[0] is Row row):
yield return row;
break;
default:
throw new NotSupportedException(
string.Format("Unpickle type {0} is not supported",
unpickled.GetType()));
}
yield return (unpickled as RowConstructor).GetRow();
}
}
}

Просмотреть файл

@@ -65,15 +65,6 @@ namespace Microsoft.Spark.Sql
s_schemaCache = new Dictionary<string, StructType>();
}
// When a row is ready to be materialized, then construct() is called
// on the RowConstructor which represents the row.
if ((args.Length == 1) && (args[0] is RowConstructor rowConstructor))
{
// Construct the Row and return args containing the Row.
args[0] = rowConstructor.GetRow();
return args;
}
// Return a new RowConstructor where the args either represent the
// schema or the row data. The parent becomes important when calling
// GetRow() on the RowConstructor containing the row data.