Fix the chained UDF with Row type (#411)

This commit is contained in:
elvaliuliuliu 2020-02-07 19:00:23 -08:00 committed by GitHub
Parent ee95ca25aa
Commit 6c8aea427b
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 78 additions and 22 deletions

View file

@ -136,7 +136,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
[Fact]
public void TestUdfWithRowType()
{
// Single Row
// Single Row.
{
Func<Column, Column> udf = Udf<Row, string>(
(row) => row.GetAs<string>("city"));
@ -149,7 +149,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
Assert.Equal(expected, actual);
}
// Multiple Rows
// Multiple Rows.
{
Func<Column, Column, Column, Column> udf = Udf<Row, Row, string, string>(
(row1, row2, str) =>
@ -173,7 +173,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
Assert.Equal(expected, actual);
}
// Nested Row
// Nested Rows.
{
Func<Column, Column> udf = Udf<Row, string>(
(row) =>
@ -197,7 +197,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
[Fact]
public void TestUdfWithReturnAsRowType()
{
// Single GenericRow
// Test UDF that returns a Row object with a single column.
{
var schema = new StructType(new[]
{
@ -219,7 +219,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
}
}
// Generic row is a part of top-level column.
// Test UDF that returns a Row object with multiple columns.
{
var schema = new StructType(new[]
{
@ -244,7 +244,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
}
}
// Nested GenericRow
// Test UDF that returns a nested Row object.
{
var subSchema1 = new StructType(new[]
{
@ -295,6 +295,30 @@ namespace Microsoft.Spark.E2ETest.UdfTests
outerCol.GetAs<Row>("col3"));
}
}
// Chained UDFs.
{
var schema = new StructType(new[]
{
new StructField("col1", new IntegerType()),
new StructField("col2", new StringType())
});
Func<Column, Column> udf1 = Udf<string>(
str => new GenericRow(new object[] { 1, str }), schema);
Func<Column, Column> udf2 = Udf<Row, string>(
row => row.GetAs<string>(1));
Row[] rows = _df.Select(udf2(udf1(_df["name"]))).Collect().ToArray();
Assert.Equal(3, rows.Length);
var expected = new[] { "Michael", "Andy", "Justin" };
for (int i = 0; i < rows.Length; ++i)
{
Assert.Equal(1, rows[i].Size());
Assert.Equal(expected[i], rows[i].GetAs<string>(0));
}
}
}
}
}

View file

@ -136,7 +136,7 @@ namespace Microsoft.Spark.UnitTest
private Pickler CreatePickler()
{
new StructTypePickler().Register();
new RowPickler().Register();
new TestUtils.RowPickler().Register();
return new Pickler();
}

View file

@ -7,6 +7,17 @@ using Razorvine.Pickle;
namespace Microsoft.Spark.Sql
{
/// <summary>
/// Custom pickler for Row objects.
/// </summary>
internal class RowPickler : IObjectPickler
{
public void pickle(object o, Stream outs, Pickler currentPickler)
{
currentPickler.save(((Row)o).Values);
}
}
/// <summary>
/// Custom pickler for GenericRow objects.
/// </summary>
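The new RowPickler mirrors the existing GenericRowPickler: it hands only Row.Values to the current Pickler, so an intermediate Row from a chained UDF is serialized as a plain value array. A rough sketch of how the pair is used, assuming Razorvine.Pickle's Pickler.dumps and the registrations shown in PythonSerDe below (the sketch itself is illustrative, not part of this commit):

    // Illustrative only: register the custom picklers once, then pickle a Row.
    Pickler.registerCustomPickler(typeof(Row), new RowPickler());
    Pickler.registerCustomPickler(typeof(GenericRow), new GenericRowPickler());

    // A schema-less Row obtained through the implicit GenericRow-to-Row
    // conversion added in Row.cs below.
    Row row = new GenericRow(new object[] { 1, "Seattle" });

    var pickler = new Pickler();
    byte[] payload = pickler.dumps(row); // RowPickler.pickle writes row.Values.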

View file

@ -3804,7 +3804,7 @@ namespace Microsoft.Spark.Sql
/// <returns>
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column> Udf(Func<GenericRow> udf, StructType returnType)
public static Func<Column> Udf(Func<Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply0;
}
@ -3816,7 +3816,7 @@ namespace Microsoft.Spark.Sql
/// <returns>
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column> Udf<T>(Func<T, GenericRow> udf, StructType returnType)
public static Func<Column, Column> Udf<T>(Func<T, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply1;
}
@ -3830,7 +3830,7 @@ namespace Microsoft.Spark.Sql
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column> Udf<T1, T2>(
Func<T1, T2, GenericRow> udf, StructType returnType)
Func<T1, T2, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply2;
}
@ -3845,7 +3845,7 @@ namespace Microsoft.Spark.Sql
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column> Udf<T1, T2, T3>(
Func<T1, T2, T3, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply3;
}
@ -3861,7 +3861,7 @@ namespace Microsoft.Spark.Sql
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4>(
Func<T1, T2, T3, T4, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply4;
}
@ -3878,7 +3878,7 @@ namespace Microsoft.Spark.Sql
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5>(
Func<T1, T2, T3, T4, T5, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply5;
}
@ -3896,7 +3896,7 @@ namespace Microsoft.Spark.Sql
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6>(
Func<T1, T2, T3, T4, T5, T6, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply6;
}
@ -3915,7 +3915,7 @@ namespace Microsoft.Spark.Sql
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7>(
Func<T1, T2, T3, T4, T5, T6, T7, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply7;
}
@ -3935,7 +3935,7 @@ namespace Microsoft.Spark.Sql
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8>(
Func<T1, T2, T3, T4, T5, T6, T7, T8, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, T8, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply8;
}
@ -3954,7 +3954,7 @@ namespace Microsoft.Spark.Sql
/// <param name="returnType">Schema associated with this row</param>
/// <returns>A delegate that when invoked will return a <see cref="Column"/> for the result of the UDF.</returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9>(
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply9;
}
@ -3976,7 +3976,7 @@ namespace Microsoft.Spark.Sql
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
/// </returns>
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>(
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, GenericRow> udf, StructType returnType)
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, Row> udf, StructType returnType)
{
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply10;
}
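Since the Udf overloads above now take Func<..., Row> instead of Func<..., GenericRow>, a UDF body can keep constructing a GenericRow: the lambda still binds because the implicit GenericRow-to-Row conversion added in Row.cs below applies to the lambda's return expression, so existing lambdas that build a GenericRow continue to compile. A minimal sketch, reusing the schema shape from the new E2E test above:

    var schema = new StructType(new[]
    {
        new StructField("col1", new IntegerType()),
        new StructField("col2", new StringType())
    });

    // Binds to Udf<T>(Func<T, Row>, StructType); the GenericRow result is
    // implicitly converted to Row.
    Func<Column, Column> udf1 = Udf<string>(
        str => new GenericRow(new object[] { 1, str }), schema);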

View file

@ -98,6 +98,6 @@ namespace Microsoft.Spark.Sql
/// Returns the hash code of the current object.
/// </summary>
/// <returns>The hash code of the current object</returns>
public override int GetHashCode() => base.GetHashCode();
}
}

View file

@ -3,8 +3,6 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Spark.Sql.Types;
namespace Microsoft.Spark.Sql
@ -36,6 +34,28 @@ namespace Microsoft.Spark.Sql
Convert();
}
/// <summary>
/// Constructor for the schema-less Row class used for chained UDFs.
/// </summary>
/// <param name="genericRow">GenericRow object</param>
internal Row(GenericRow genericRow)
{
_genericRow = genericRow;
}
/// <summary>
/// Returns schema-less Row which can happen within chained UDFs (same behavior as PySpark).
/// </summary>
/// <remarks>
/// The use of this conversion operator is discouraged except for the UDF that returns
/// a Row object.
/// </remarks>
/// <returns>schema-less Row</returns>
public static implicit operator Row(GenericRow genericRow)
{
return new Row(genericRow);
}
/// <summary>
/// Schema associated with this row.
/// </summary>
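Because the Row produced by this conversion carries no schema (matching PySpark's behavior for intermediate rows in chained UDFs, per the remarks above), downstream code reads its fields by ordinal rather than by name, as the new chained-UDF test does. A small illustration, not part of this commit:

    // A schema-less Row created through the implicit conversion.
    Row row = new GenericRow(new object[] { 1, "Seattle" });

    // No schema is attached, so fields are accessed positionally.
    string city = row.GetAs<string>(1); // "Seattle"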

View file

@ -36,7 +36,8 @@ namespace Microsoft.Spark.Utils
Unpickler.registerConstructor(
"pyspark.sql.types", "_create_row_inbound_converter", s_rowConstructor);
// Register custom pickler for GenericRow objects.
// Register custom picklers.
Pickler.registerCustomPickler(typeof(Row), new RowPickler());
Pickler.registerCustomPickler(typeof(GenericRow), new GenericRowPickler());
}