зеркало из https://github.com/dotnet/spark.git
Fix the chained UDF with Row type (#411)
This commit is contained in:
Родитель
ee95ca25aa
Коммит
6c8aea427b
|
@ -136,7 +136,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
|
|||
[Fact]
|
||||
public void TestUdfWithRowType()
|
||||
{
|
||||
// Single Row
|
||||
// Single Row.
|
||||
{
|
||||
Func<Column, Column> udf = Udf<Row, string>(
|
||||
(row) => row.GetAs<string>("city"));
|
||||
|
@ -149,7 +149,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
|
|||
Assert.Equal(expected, actual);
|
||||
}
|
||||
|
||||
// Multiple Rows
|
||||
// Multiple Rows.
|
||||
{
|
||||
Func<Column, Column, Column, Column> udf = Udf<Row, Row, string, string>(
|
||||
(row1, row2, str) =>
|
||||
|
@ -173,7 +173,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
|
|||
Assert.Equal(expected, actual);
|
||||
}
|
||||
|
||||
// Nested Row
|
||||
// Nested Rows.
|
||||
{
|
||||
Func<Column, Column> udf = Udf<Row, string>(
|
||||
(row) =>
|
||||
|
@ -197,7 +197,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
|
|||
[Fact]
|
||||
public void TestUdfWithReturnAsRowType()
|
||||
{
|
||||
// Single GenericRow
|
||||
// Test UDF that returns a Row object with a single column.
|
||||
{
|
||||
var schema = new StructType(new[]
|
||||
{
|
||||
|
@ -219,7 +219,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
|
|||
}
|
||||
}
|
||||
|
||||
// Generic row is a part of top-level column.
|
||||
// Test UDF that returns a Row object with multiple columns.
|
||||
{
|
||||
var schema = new StructType(new[]
|
||||
{
|
||||
|
@ -244,7 +244,7 @@ namespace Microsoft.Spark.E2ETest.UdfTests
|
|||
}
|
||||
}
|
||||
|
||||
// Nested GenericRow
|
||||
// Test UDF that returns a nested Row object.
|
||||
{
|
||||
var subSchema1 = new StructType(new[]
|
||||
{
|
||||
|
@ -295,6 +295,30 @@ namespace Microsoft.Spark.E2ETest.UdfTests
|
|||
outerCol.GetAs<Row>("col3"));
|
||||
}
|
||||
}
|
||||
|
||||
// Chained UDFs: the Row produced by one UDF is fed directly into a second UDF.
|
||||
{
|
||||
// Return schema declared for udf1: an integer column followed by a string column.
var schema = new StructType(new[]
|
||||
{
|
||||
new StructField("col1", new IntegerType()),
|
||||
new StructField("col2", new StringType())
|
||||
});
|
||||
// udf1 wraps each input string into a two-field GenericRow (1, str).
// NOTE(review): the lambda returns GenericRow; presumably this binds to a Row-returning
// Udf overload through the implicit GenericRow-to-Row conversion - confirm.
Func<Column, Column> udf1 = Udf<string>(
|
||||
str => new GenericRow(new object[] { 1, str }), schema);
|
||||
|
||||
// udf2 reads field index 1 of the incoming Row, i.e. the string that udf1 stored.
Func<Column, Column> udf2 = Udf<Row, string>(
|
||||
row => row.GetAs<string>(1));
|
||||
|
||||
// udf1's output Column is passed straight to udf2 within a single Select.
Row[] rows = _df.Select(udf2(udf1(_df["name"]))).Collect().ToArray();
|
||||
// Assumes the _df fixture holds exactly three rows - it is defined outside this view.
Assert.Equal(3, rows.Length);
|
||||
|
||||
// The round trip through both UDFs should give back the original "name" values;
// presumably listed here in _df's row order - confirm against the fixture.
var expected = new[] { "Michael", "Andy", "Justin" };
|
||||
for (int i = 0; i < rows.Length; ++i)
|
||||
{
|
||||
// Each collected row carries only udf2's single string output.
Assert.Equal(1, rows[i].Size());
|
||||
Assert.Equal(expected[i], rows[i].GetAs<string>(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -136,7 +136,7 @@ namespace Microsoft.Spark.UnitTest
|
|||
/// <summary>
/// Registers the StructType and Row picklers and then returns a new <see cref="Pickler"/>.
/// </summary>
/// <returns>A newly constructed <see cref="Pickler"/>.</returns>
private Pickler CreatePickler()
|
||||
{
|
||||
new StructTypePickler().Register();
|
||||
// NOTE(review): both the unqualified and the TestUtils-qualified RowPickler lines appear
// here; presumably they are the old and new side of the same changed line, with the
// qualification added to avoid a clash with another RowPickler type - confirm.
new RowPickler().Register();
|
||||
new TestUtils.RowPickler().Register();
|
||||
return new Pickler();
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,17 @@ using Razorvine.Pickle;
|
|||
|
||||
namespace Microsoft.Spark.Sql
|
||||
{
|
||||
/// <summary>
|
||||
/// Custom pickler for Row objects.
|
||||
/// </summary>
|
||||
internal class RowPickler : IObjectPickler
|
||||
{
|
||||
/// <summary>
/// Pickles a <see cref="Row"/> by saving its <c>Values</c> through the current pickler.
/// </summary>
/// <param name="o">Object to pickle; it is cast to <see cref="Row"/> without a type check.</param>
/// <param name="outs">Output stream; not referenced directly in this method body.</param>
/// <param name="currentPickler">Pickler whose <c>save</c> method receives the row values.</param>
public void pickle(object o, Stream outs, Pickler currentPickler)
|
||||
{
|
||||
// Only the row's Values are saved; this body does not reference the row's schema.
currentPickler.save(((Row)o).Values);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Custom pickler for GenericRow objects.
|
||||
/// </summary>
|
|
@ -3804,7 +3804,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// <returns>
|
||||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column> Udf(Func<GenericRow> udf, StructType returnType)
|
||||
public static Func<Column> Udf(Func<Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply0;
|
||||
}
|
||||
|
@ -3816,7 +3816,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// <returns>
|
||||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column, Column> Udf<T>(Func<T, GenericRow> udf, StructType returnType)
|
||||
public static Func<Column, Column> Udf<T>(Func<T, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply1;
|
||||
}
|
||||
|
@ -3830,7 +3830,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column, Column, Column> Udf<T1, T2>(
|
||||
Func<T1, T2, GenericRow> udf, StructType returnType)
|
||||
Func<T1, T2, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply2;
|
||||
}
|
||||
|
@ -3845,7 +3845,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column, Column, Column, Column> Udf<T1, T2, T3>(
|
||||
Func<T1, T2, T3, GenericRow> udf, StructType returnType)
|
||||
Func<T1, T2, T3, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply3;
|
||||
}
|
||||
|
@ -3861,7 +3861,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4>(
|
||||
Func<T1, T2, T3, T4, GenericRow> udf, StructType returnType)
|
||||
Func<T1, T2, T3, T4, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply4;
|
||||
}
|
||||
|
@ -3878,7 +3878,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5>(
|
||||
Func<T1, T2, T3, T4, T5, GenericRow> udf, StructType returnType)
|
||||
Func<T1, T2, T3, T4, T5, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply5;
|
||||
}
|
||||
|
@ -3896,7 +3896,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6>(
|
||||
Func<T1, T2, T3, T4, T5, T6, GenericRow> udf, StructType returnType)
|
||||
Func<T1, T2, T3, T4, T5, T6, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply6;
|
||||
}
|
||||
|
@ -3915,7 +3915,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7>(
|
||||
Func<T1, T2, T3, T4, T5, T6, T7, GenericRow> udf, StructType returnType)
|
||||
Func<T1, T2, T3, T4, T5, T6, T7, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply7;
|
||||
}
|
||||
|
@ -3935,7 +3935,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8>(
|
||||
Func<T1, T2, T3, T4, T5, T6, T7, T8, GenericRow> udf, StructType returnType)
|
||||
Func<T1, T2, T3, T4, T5, T6, T7, T8, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply8;
|
||||
}
|
||||
|
@ -3954,7 +3954,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// <param name="returnType">Schema associated with this row</param>
|
||||
/// <returns>A delegate that when invoked will return a <see cref="Column"/> for the result of the UDF.</returns>
|
||||
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9>(
|
||||
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, GenericRow> udf, StructType returnType)
|
||||
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply9;
|
||||
}
|
||||
|
@ -3976,7 +3976,7 @@ namespace Microsoft.Spark.Sql
|
|||
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
|
||||
/// </returns>
|
||||
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>(
|
||||
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, GenericRow> udf, StructType returnType)
|
||||
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, Row> udf, StructType returnType)
|
||||
{
|
||||
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply10;
|
||||
}
|
||||
|
|
|
@ -98,6 +98,6 @@ namespace Microsoft.Spark.Sql
|
|||
/// Returns the hash code of the current object.
|
||||
/// </summary>
|
||||
/// <returns>The hash code of the current object</returns>
|
||||
public override int GetHashCode() => base.GetHashCode();
|
||||
public override int GetHashCode() => base.GetHashCode();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,6 @@
|
|||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using Microsoft.Spark.Sql.Types;
|
||||
|
||||
namespace Microsoft.Spark.Sql
|
||||
|
@ -36,6 +34,28 @@ namespace Microsoft.Spark.Sql
|
|||
Convert();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a schema-less Row that wraps a GenericRow; used for chained UDFs.
|
||||
/// </summary>
|
||||
/// <param name="genericRow">GenericRow object to wrap</param>
|
||||
internal Row(GenericRow genericRow)
|
||||
{
|
||||
// Only the wrapped GenericRow is stored; nothing else is initialized in this constructor.
_genericRow = genericRow;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Implicitly converts a GenericRow into a schema-less Row, a case that can occur within
/// chained UDFs (same behavior as PySpark).
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// The use of this conversion operator is discouraged except for a UDF that returns
|
||||
/// a Row object.
|
||||
/// </remarks>
|
||||
/// <param name="genericRow">GenericRow object to convert</param>
/// <returns>A schema-less Row wrapping the given GenericRow</returns>
|
||||
public static implicit operator Row(GenericRow genericRow)
|
||||
{
|
||||
// Delegates to the Row(GenericRow) constructor.
return new Row(genericRow);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Schema associated with this row.
|
||||
/// </summary>
|
||||
|
|
|
@ -36,7 +36,8 @@ namespace Microsoft.Spark.Utils
|
|||
Unpickler.registerConstructor(
|
||||
"pyspark.sql.types", "_create_row_inbound_converter", s_rowConstructor);
|
||||
|
||||
// Register custom pickler for GenericRow objects.
|
||||
// Register custom picklers.
|
||||
Pickler.registerCustomPickler(typeof(Row), new RowPickler());
|
||||
Pickler.registerCustomPickler(typeof(GenericRow), new GenericRowPickler());
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче