Merge pull request #176 from jayjaywg/master

Refactor DataFrame schema, create data type classes, and add CreateDataFrame to SqlContext
Committed by guwang on 2015-12-20 22:20:41 -08:00
Parents: 333c8a34b6 ebd55c3ce8
Commit: 9b75c6fe1b
25 changed files with 751 additions and 460 deletions

View file

@@ -76,6 +76,7 @@
<Compile Include="Core\StatCounter.cs" />
<Compile Include="Core\StatusTracker.cs" />
<Compile Include="Core\StorageLevel.cs" />
<Compile Include="Interop\Ipc\JsonSerDe.cs" />
<Compile Include="Interop\SparkCLREnvironment.cs" />
<Compile Include="Interop\Ipc\IJvmBridge.cs" />
<Compile Include="Interop\Ipc\JvmBridge.cs" />
@@ -118,7 +119,7 @@
<Compile Include="Sql\Functions.cs" />
<Compile Include="Sql\SaveMode.cs" />
<Compile Include="Sql\SqlContext.cs" />
<Compile Include="Sql\Struct.cs" />
<Compile Include="Sql\Types.cs" />
<Compile Include="Sql\UserDefinedFunction.cs" />
<Compile Include="Streaming\DStream.cs" />
<Compile Include="Streaming\PairDStreamFunctions.cs" />

View file

@@ -283,7 +283,7 @@ namespace Microsoft.Spark.CSharp.Core
/// <returns></returns>
public RDD<T> Distinct(int numPartitions = 0)
{
return Map(x => new KeyValuePair<T, int>(x, 0)).ReduceByKey((x,y) => x, numPartitions).Map<T>(x => x.Key);
return Map(x => new KeyValuePair<T, int>(x, 0)).ReduceByKey((x, y) => x, numPartitions).Map<T>(x => x.Key);
}
/// <summary>
@@ -417,7 +417,7 @@ namespace Microsoft.Spark.CSharp.Core
else
{
const double delta = 0.00005;
var gamma = - Math.Log(delta) / total;
var gamma = -Math.Log(delta) / total;
return Math.Min(1, fraction + gamma + Math.Sqrt(gamma * gamma + 2 * gamma * fraction));
}
}
@@ -811,6 +811,7 @@ namespace Microsoft.Spark.CSharp.Core
int left = num - items.Count;
IEnumerable<int> partitions = Enumerable.Range(partsScanned, Math.Min(numPartsToTry, totalParts - partsScanned));
var mappedRDD = MapPartitionsWithIndex<T>(new TakeHelper<T>(left).Execute);
int port = sparkContext.SparkContextProxy.RunJob(mappedRDD.RddProxy, partitions);
@@ -867,7 +868,7 @@
{
return Map<KeyValuePair<T, T>>(v => new KeyValuePair<T, T>(v, default(T))).SubtractByKey
(
other.Map<KeyValuePair<T, T>>(v => new KeyValuePair<T, T>(v, default(T))),
other.Map<KeyValuePair<T, T>>(v => new KeyValuePair<T, T>(v, default(T))),
numPartitions
)
.Keys();
@@ -1044,7 +1045,7 @@
/// <returns></returns>
public IEnumerable<T> ToLocalIterator()
{
foreach(int partition in Enumerable.Range(0, GetNumPartitions()))
foreach (int partition in Enumerable.Range(0, GetNumPartitions()))
{
var mappedRDD = MapPartitionsWithIndex<T>((pid, iter) => iter);
int port = sparkContext.SparkContextProxy.RunJob(mappedRDD.RddProxy, Enumerable.Range(partition, 1));
@@ -1382,7 +1383,7 @@
internal KeyValuePair<K, T> Execute(T input)
{
return new KeyValuePair<K,T>(func(input), input);
return new KeyValuePair<K, T>(func(input), input);
}
}
[Serializable]
@@ -1429,7 +1430,7 @@
else if (y.Value)
return x;
else
return new KeyValuePair<T,bool>(func(x.Key, y.Key), false);
return new KeyValuePair<T, bool>(func(x.Key, y.Key), false);
}
}
[Serializable]

View file

@@ -0,0 +1,68 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System.Linq;
using Newtonsoft.Json.Linq;
namespace Microsoft.Spark.CSharp.Interop.Ipc
{
/// <summary>
/// Json.NET Serialization/Deserialization helper class.
/// </summary>
public static class JsonSerDe
{
// Note: the Scala side uses JSortedObject when parsing JSON, so the properties in a JObject need to be sorted
/// <summary>
/// Extension method to sort the properties in a JSON object by key.
/// </summary>
/// <param name="jObject"></param>
/// <returns></returns>
public static JObject SortProperties(this JObject jObject)
{
JObject sortedJObject = new JObject();
foreach (var property in jObject.Properties().OrderBy(p => p.Name))
{
if (property.Value is JObject)
{
var propJObject = property.Value as JObject;
sortedJObject.Add(property.Name, propJObject.SortProperties());
}
else if (property.Value is JArray)
{
var propJArray = property.Value as JArray;
sortedJObject.Add(property.Name, propJArray.SortProperties());
}
else
{
sortedJObject.Add(property.Name, property.Value);
}
}
return sortedJObject;
}
/// <summary>
/// Extension method to sort the items in a JSON array by key.
/// </summary>
public static JArray SortProperties(this JArray jArray)
{
JArray sortedJArray = new JArray();
if (jArray.Count == 0) return jArray;
foreach (var item in jArray)
{
if (item is JObject)
{
var sortedItem = ((JObject)item).SortProperties();
sortedJArray.Add(sortedItem);
}
else if (item is JArray)
{
var sortedItem = ((JArray)item).SortProperties();
sortedJArray.Add(sortedItem);
}
}
return sortedJArray;
}
}
}
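A minimal usage sketch of the extension above (the schema literal is illustrative):

using System;
using Newtonsoft.Json.Linq;
using Microsoft.Spark.CSharp.Interop.Ipc;

var unsorted = JObject.Parse(@"{""type"":""struct"",""fields"":[],""metadata"":{}}");
// Properties come back in key order (fields, metadata, type), which is what the
// Scala side's JSortedObject-based parser expects.
Console.WriteLine(unsorted.SortProperties().ToString(Newtonsoft.Json.Formatting.None));
// prints: {"fields":[],"metadata":{},"type":"struct"}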

View file

@@ -17,8 +17,6 @@ namespace Microsoft.Spark.CSharp.Proxy
ISparkContextProxy SparkContextProxy { get; }
ISparkConfProxy CreateSparkConf(bool loadDefaults = true);
ISparkContextProxy CreateSparkContext(ISparkConfProxy conf);
IStructFieldProxy CreateStructField(string name, string dataType, bool isNullable);
IStructTypeProxy CreateStructType(List<StructField> fields);
IDStreamProxy CreateCSharpDStream(IDStreamProxy jdstream, byte[] func, string deserializer);
IDStreamProxy CreateCSharpTransformed2DStream(IDStreamProxy jdstream, IDStreamProxy jother, byte[] func, string deserializer, string deserializerOther);
IDStreamProxy CreateCSharpReducedWindowedDStream(IDStreamProxy jdstream, byte[] func, byte[] invFunc, int windowSeconds, int slideSeconds, string deserializer);

View file

@@ -13,6 +13,7 @@ namespace Microsoft.Spark.CSharp.Proxy
{
internal interface ISqlContextProxy
{
IDataFrameProxy CreateDataFrame(IRDDProxy rddProxy, IStructTypeProxy structTypeProxy);
IDataFrameProxy ReadDataFrame(string path, StructType schema, Dictionary<string, string> options);
IDataFrameProxy JsonFile(string path);
IDataFrameProxy TextFile(string path, StructType schema, string delimiter);

View file

@@ -6,7 +6,6 @@ using System.Collections.Generic;
using System.Linq;
using Microsoft.Spark.CSharp.Core;
using Microsoft.Spark.CSharp.Interop.Ipc;
using Microsoft.Spark.CSharp.Sql;
namespace Microsoft.Spark.CSharp.Proxy.Ipc
{
@@ -77,7 +76,7 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
return
SparkCLRIpcProxy.JvmBridge.CallNonStaticJavaMethod(
jvmDataFrameReference, "showString",
new object[] { numberOfRows , truncate }).ToString();
new object[] { numberOfRows, truncate }).ToString();
}
public bool IsLocal()

View file

@@ -62,33 +62,6 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
return sparkContextProxy;
}
public IStructFieldProxy CreateStructField(string name, string dataType, bool isNullable)
{
return new StructFieldIpcProxy(
new JvmObjectReference(
JvmBridge.CallStaticJavaMethod(
"org.apache.spark.sql.api.csharp.SQLUtils", "createStructField",
new object[] { name, dataType, isNullable }).ToString()
)
);
}
public IStructTypeProxy CreateStructType(List<StructField> fields)
{
var fieldsReference = fields.Select(s => (s.StructFieldProxy as StructFieldIpcProxy).JvmStructFieldReference).ToList().Cast<JvmObjectReference>();
var seq =
new JvmObjectReference(
JvmBridge.CallStaticJavaMethod("org.apache.spark.sql.api.csharp.SQLUtils",
"toSeq", new object[] { fieldsReference }).ToString());
return new StructTypeIpcProxy(
new JvmObjectReference(
JvmBridge.CallStaticJavaMethod("org.apache.spark.sql.api.csharp.SQLUtils", "createStructType", new object[] { seq }).ToString()
)
);
}
public IDStreamProxy CreateCSharpDStream(IDStreamProxy jdstream, byte[] func, string deserializer)
{
var jvmDStreamReference = SparkCLRIpcProxy.JvmBridge.CallConstructor("org.apache.spark.streaming.api.csharp.CSharpDStream",

View file

@@ -116,7 +116,6 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
return new RDDIpcProxy(jvmRddReference);
}
//TODO - this implementation is slow. Replace with call to createRDDFromArray() in CSharpRDD
public IRDDProxy Parallelize(IEnumerable<byte[]> values, int numSlices)
{
var jvmRddReference = new JvmObjectReference((string)SparkCLRIpcProxy.JvmBridge.CallStaticJavaMethod("org.apache.spark.api.csharp.CSharpRDD", "createRDDFromArray", new object[] { jvmSparkContextReference, values, numSlices }));

View file

@@ -21,6 +21,17 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
this.jvmSqlContextReference = jvmSqlContextReference;
}
public IDataFrameProxy CreateDataFrame(IRDDProxy rddProxy, IStructTypeProxy structTypeProxy)
{
var rdd = new JvmObjectReference(SparkCLRIpcProxy.JvmBridge.CallStaticJavaMethod("org.apache.spark.sql.api.csharp.SQLUtils", "byteArrayRDDToAnyArrayRDD",
new object[] { (rddProxy as RDDIpcProxy).JvmRddReference }).ToString());
return new DataFrameIpcProxy(
new JvmObjectReference(
SparkCLRIpcProxy.JvmBridge.CallNonStaticJavaMethod(jvmSqlContextReference, "applySchemaToPythonRDD",
new object[] { rdd, (structTypeProxy as StructTypeIpcProxy).JvmStructTypeReference }).ToString()), this);
}
public IDataFrameProxy ReadDataFrame(string path, StructType schema, Dictionary<string, string> options)
{
//TODO parameter Dictionary<string, string> options is not used right now - it is meant to be passed on to data sources
@@ -44,7 +55,7 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc
new JvmObjectReference(
SparkCLRIpcProxy.JvmBridge.CallStaticJavaMethod(
"org.apache.spark.sql.api.csharp.SQLUtils", "loadTextFile",
new object[] {jvmSqlContextReference, path, delimiter, (schema.StructTypeProxy as StructTypeIpcProxy).JvmStructTypeReference}).ToString()
new object[] { jvmSqlContextReference, path, delimiter, schema.Json}).ToString()
), this
);
}

View file

@@ -22,7 +22,7 @@ namespace Microsoft.Spark.CSharp.Sql
private readonly IDataFrameProxy dataFrameProxy;
[NonSerialized]
private readonly SparkContext sparkContext;
[NonSerialized]
private StructType schema;
[NonSerialized]
private RDD<Row> rdd;
@@ -40,7 +40,7 @@ namespace Microsoft.Spark.CSharp.Sql
if (rdd == null)
{
rddProxy = dataFrameProxy.JavaToCSharp();
rdd = new RDD<Row>(rddProxy, sparkContext, SerializedMode.Row);
rdd = new RDD<Row>(rddProxy, sparkContext, SerializedMode.Row);
}
return rdd;
}
@@ -137,7 +137,7 @@ namespace Microsoft.Spark.CSharp.Sql
/// </summary>
public void ShowSchema()
{
List<string> nameTypeList = Schema.Fields.Select(structField => string.Format("{0}:{1}", structField.Name, structField.DataType.SimpleString())).ToList();
var nameTypeList = Schema.Fields.Select(structField => structField.SimpleString);
Console.WriteLine(string.Join(", ", nameTypeList));
}
@@ -145,18 +145,18 @@
/// Returns all Rows in this DataFrame
/// </summary>
public IEnumerable<Row> Collect()
{
{
int port = RddProxy.CollectAndServe();
return Rdd.Collect(port).Cast<Row>();
}
/// <summary>
/// Converts the DataFrame to RDD of byte[]
/// Converts the DataFrame to RDD of Row
/// </summary>
/// <returns>resulting RDD</returns>
public RDD<byte[]> ToRDD() //RDD created using byte representation of GenericRow objects
public RDD<Row> ToRDD() //RDD created using byte representation of Row objects
{
return new RDD<byte[]>(dataFrameProxy.ToRDD(), sparkContext);
return Rdd;
}
/// <summary>

View file

@@ -2,13 +2,9 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Microsoft.Spark.CSharp.Services;
namespace Microsoft.Spark.CSharp.Sql
@@ -31,24 +27,24 @@ namespace Microsoft.Spark.CSharp.Sql
/// <summary>
/// Schema for the row.
/// </summary>
public abstract RowSchema GetSchema();
public abstract StructType GetSchema();
/// <summary>
/// Returns the value at position i.
/// </summary>
public abstract object Get(int i);
public abstract dynamic Get(int i);
/// <summary>
/// Returns the value of a given columnName.
/// </summary>
public abstract object Get(string columnName);
public abstract dynamic Get(string columnName);
/// <summary>
/// Returns the value at position i, the return value will be cast to type T.
/// </summary>
public T GetAs<T>(int i)
{
object o = Get(i);
dynamic o = Get(i);
try
{
T result = (T)o;
@@ -66,7 +62,7 @@ namespace Microsoft.Spark.CSharp.Sql
/// </summary>
public T GetAs<T>(string columnName)
{
object o = Get(columnName);
dynamic o = Get(columnName);
try
{
T result = (T)o;
@@ -80,145 +76,12 @@ namespace Microsoft.Spark.CSharp.Sql
}
}
/// <summary>
/// Schema of Row
/// </summary>
[Serializable]
public class RowSchema
{
public string type;
public List<ColumnSchema> columns;
private readonly Dictionary<string, int> columnName2Index = new Dictionary<string, int>();
public RowSchema(string type)
{
this.type = type;
this.columns = new List<ColumnSchema>();
}
public RowSchema(string type, List<ColumnSchema> cols)
{
int index = 0;
foreach (var col in cols)
{
string columnName = col.name;
//TODO - investigate the issue and uncomment the following checks
/*
* UDFs produce empty column name. Commenting out the following code at the time of upgrading to 1.5.2
*/
//if (string.IsNullOrEmpty(columnName))
//{
// throw new Exception(string.Format("Null column name at pos: {0}", index));
//}
//if (columnName2Index.ContainsKey(columnName))
//{
// throw new Exception(string.Format("duplicate column name ({0}) in pos ({1}) and ({2})",
// columnName, columnName2Index[columnName], index));
//}
columnName2Index[columnName] = index;
index++;
}
this.type = type;
this.columns = cols;
}
internal int GetIndexByColumnName(string ColumnName)
{
if (!columnName2Index.ContainsKey(ColumnName))
{
throw new Exception(string.Format("unknown ColumnName: {0}", ColumnName));
}
return columnName2Index[ColumnName];
}
public override string ToString()
{
string result;
if (columns.Any())
{
List<string> cols = new List<string>();
foreach (var col in columns)
{
cols.Add(col.ToString());
}
result = "{" +
string.Format("type: {0}, columns: [{1}]", type, string.Join(", ", cols.ToArray())) +
"}";
}
else
{
result = type;
}
return result;
}
internal static RowSchema ParseRowSchemaFromJson(string json)
{
JObject joType = JObject.Parse(json);
string type = joType["type"].ToString();
List<ColumnSchema> columns = new List<ColumnSchema>();
List<JToken> jtFields = joType["fields"].Children().ToList();
foreach (JToken jtField in jtFields)
{
ColumnSchema col = ColumnSchema.ParseColumnSchemaFromJson(jtField.ToString());
columns.Add(col);
}
return new RowSchema(type, columns);
}
}
/// <summary>
/// Schema for column
/// </summary>
[Serializable]
public class ColumnSchema
{
public string name;
public RowSchema type;
public bool nullable;
public override string ToString()
{
string str = string.Format("name: {0}, type: {1}, nullable: {2}", name, type, nullable);
return "{" + str + "}";
}
internal static ColumnSchema ParseColumnSchemaFromJson(string json)
{
ColumnSchema col = new ColumnSchema();
JObject joField = JObject.Parse(json);
col.name = joField["name"].ToString();
col.nullable = (bool)(joField["nullable"]);
JToken jtType = joField["type"];
if (jtType.Type == JTokenType.String)
{
col.type = new RowSchema(joField["type"].ToString());
}
else
{
col.type = RowSchema.ParseRowSchemaFromJson(joField["type"].ToString());
}
return col;
}
}
[Serializable]
internal class RowImpl : Row
{
private readonly RowSchema schema;
private readonly object[] values;
private readonly StructType schema;
public dynamic[] Values { get { return values; } }
private readonly dynamic[] values;
private readonly int columnCount;
@@ -229,17 +92,15 @@ namespace Microsoft.Spark.CSharp.Sql
return Get(index);
}
}
internal RowImpl(object data, RowSchema schema)
internal RowImpl(dynamic data, StructType schema)
{
if (data is object[])
if (data is dynamic[])
{
values = data as object[];
values = data as dynamic[];
}
else if (data is List<object>)
else if (data is List<dynamic>)
{
values = (data as List<object>).ToArray();
values = (data as List<dynamic>).ToArray();
}
else
{
@@ -249,7 +110,7 @@ namespace Microsoft.Spark.CSharp.Sql
this.schema = schema;
columnCount = values.Count();
int schemaColumnCount = this.schema.columns.Count();
int schemaColumnCount = this.schema.Fields.Count();
if (columnCount != schemaColumnCount)
{
throw new Exception(string.Format("column count inferred from data ({0}) and schema ({1}) mismatch", columnCount, schemaColumnCount));
@@ -263,12 +124,12 @@ namespace Microsoft.Spark.CSharp.Sql
return columnCount;
}
public override RowSchema GetSchema()
public override StructType GetSchema()
{
return schema;
}
public override object Get(int i)
public override dynamic Get(int i)
{
if (i >= columnCount)
{
@@ -278,9 +139,9 @@ namespace Microsoft.Spark.CSharp.Sql
return values[i];
}
public override object Get(string columnName)
public override dynamic Get(string columnName)
{
int index = schema.GetIndexByColumnName(columnName);
int index = schema.Fields.FindIndex(f => f.Name == columnName); // case sensitive
return Get(index);
}
@@ -305,21 +166,69 @@ namespace Microsoft.Spark.CSharp.Sql
private void Initialize()
{
int index = 0;
foreach (var col in schema.columns)
foreach (var field in schema.Fields)
{
if (col.type.columns.Any()) // this column itself is a sub-row
if (field.DataType is ArrayType)
{
object value = values[index];
Func<DataType, int, StructType> convertArrayTypeToStructTypeFunc = (dataType, length) =>
{
StructField[] fields = new StructField[length];
for(int i = 0; i < length ; i++)
{
fields[i] = new StructField(string.Format("_array_{0}", i), dataType);
}
return new StructType(fields);
};
var elementType = (field.DataType as ArrayType).ElementType;
// Note: when creating an object from JSON, PySpark converts a JSON array to a Python list (https://github.com/apache/spark/blob/branch-1.4/python/pyspark/sql/types.py, _create_cls(dataType)),
// and the Pyrolite unpickler converts a Python list to a C# ArrayList (https://github.com/irmen/Pyrolite/blob/v4.10/README.txt), so values[index] should be of type ArrayList.
// In case Python changes its implementation and the value is not an ArrayList, try casting to object[], because the Pyrolite unpickler converts a Python tuple to a C# object[].
object[] valueOfArray = values[index] is ArrayList ? (values[index] as ArrayList).ToArray() : values[index] as object[];
if (valueOfArray == null)
{
throw new ArgumentException("Cannot parse data of ArrayType: " + field.Name);
}
values[index] = new RowImpl(valueOfArray, elementType as StructType ?? convertArrayTypeToStructTypeFunc(elementType, valueOfArray.Length)).values;
}
else if (field.DataType is MapType)
{
//TODO
throw new NotImplementedException();
}
else if (field.DataType is StructType)
{
dynamic value = values[index];
if (value != null)
{
RowImpl subRow = new RowImpl(values[index], col.type);
var subRow = new RowImpl(values[index], field.DataType as StructType);
values[index] = subRow;
}
}
else if (field.DataType is DecimalType)
{
//TODO
throw new NotImplementedException();
}
else if (field.DataType is DateType)
{
//TODO
throw new NotImplementedException();
}
else if (field.DataType is StringType)
{
if (values[index] != null) values[index] = values[index].ToString();
}
else
{
values[index] = values[index];
}
index++;
}
}
}
}
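The ArrayList-versus-object[] branch in Initialize is the subtle part. A standalone sketch of the same normalization (hypothetical helper, not SparkCLR code):

using System;
using System.Collections;

static object[] NormalizeArrayValue(object value)
{
    // Pyrolite unpickles a Python list to an ArrayList and a Python tuple to object[];
    // accept both shapes, as Initialize does for ArrayType fields.
    if (value is ArrayList) return ((ArrayList)value).ToArray();
    if (value is object[]) return (object[])value;
    throw new ArgumentException("Cannot parse data of ArrayType");
}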

View file

@@ -1,11 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Razorvine.Pickle;
namespace Microsoft.Spark.CSharp.Sql
@@ -65,7 +60,8 @@ namespace Microsoft.Spark.CSharp.Sql
/// <returns></returns>
public Row GetRow()
{
var row = new RowImpl(GetValues(Values), RowSchema.ParseRowSchemaFromJson(Schema));
var schema = DataType.ParseDataTypeFromJson(Schema) as StructType;
var row = new RowImpl(GetValues(Values), schema);
//Resetting schema here so that rows from multiple DataFrames can be processed in the same AppDomain
//next row will have schema - so resetting is fine

View file

@@ -8,6 +8,7 @@ using System.Text;
using System.Threading.Tasks;
using Microsoft.Spark.CSharp.Core;
using Microsoft.Spark.CSharp.Interop;
using Microsoft.Spark.CSharp.Interop.Ipc;
using Microsoft.Spark.CSharp.Proxy;
namespace Microsoft.Spark.CSharp.Sql
@@ -39,9 +40,16 @@ namespace Microsoft.Spark.CSharp.Sql
return new DataFrame(sqlContextProxy.ReadDataFrame(path, schema, options), sparkContext);
}
public DataFrame CreateDataFrame(RDD<byte[]> rdd, StructType schema)
public DataFrame CreateDataFrame(RDD<object[]> rdd, StructType schema)
{
throw new NotImplementedException();
// Note: this sets up the RDD for pickling; the conversion to RDD<byte[]> happens in CSharpWorker.
// sqlContextProxy.CreateDataFrame() below calls byteArrayRDDToAnyArrayRDD() in SQLUtils.scala, which only accepts an RDD of type RDD[Array[Byte]].
// byteArrayRDDToAnyArrayRDD() calls SerDeUtil.pythonToJava(), which is a mapPartitions internally,
// so it does not execute until CSharpWorker finishes pickling to RDD[Array[Byte]].
var rddRow = rdd.Map(r => r);
rddRow.serializedMode = SerializedMode.Row;
return new DataFrame(sqlContextProxy.CreateDataFrame(rddRow.RddProxy, schema.StructTypeProxy), sparkContext);
}
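Given that flow, a caller supplies an RDD<object[]> whose element arrays line up positionally with the schema fields. A condensed sketch (sc and sqlContext assumed in scope; see DFCreateDataFrameSample added in this commit for the full version):

var schema = new StructType(new[]
{
    new StructField("name", new StringType()),
    new StructField("age", new IntegerType())
});
var rdd = sc.Parallelize(new List<object[]> { new object[] { "Bill", 43 } });
var df = sqlContext.CreateDataFrame(rdd, schema);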
/// <summary>

View file

@@ -1,135 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System.Collections.Generic;
using System.Linq;
using Microsoft.Spark.CSharp.Interop;
using Microsoft.Spark.CSharp.Proxy;
using Microsoft.Spark.CSharp.Proxy.Ipc;
namespace Microsoft.Spark.CSharp.Sql
{
/// <summary>
/// Schema of DataFrame
/// </summary>
public class StructType
{
private readonly IStructTypeProxy structTypeProxy;
internal IStructTypeProxy StructTypeProxy
{
get
{
return structTypeProxy;
}
}
public List<StructField> Fields //TODO - avoid calling method everytime
{
get
{
var structTypeFieldJvmObjectReferenceList =
structTypeProxy.GetStructTypeFields();
var structFieldList = new List<StructField>(structTypeFieldJvmObjectReferenceList.Count);
structFieldList.AddRange(
structTypeFieldJvmObjectReferenceList.Select(
structTypeFieldJvmObjectReference => new StructField(structTypeFieldJvmObjectReference)));
return structFieldList;
}
}
public string ToJson()
{
return structTypeProxy.ToJson();
}
internal StructType(IStructTypeProxy structTypeProxy)
{
this.structTypeProxy = structTypeProxy;
}
public static StructType CreateStructType(List<StructField> structFields)
{
return new StructType(SparkCLREnvironment.SparkCLRProxy.CreateStructType(structFields));
}
}
/// <summary>
/// Schema for DataFrame column
/// </summary>
public class StructField
{
private readonly IStructFieldProxy structFieldProxy;
internal IStructFieldProxy StructFieldProxy
{
get
{
return structFieldProxy;
}
}
public string Name
{
get
{
return structFieldProxy.GetStructFieldName();
}
}
public DataType DataType
{
get
{
return new DataType(structFieldProxy.GetStructFieldDataType());
}
}
public bool IsNullable
{
get
{
return structFieldProxy.GetStructFieldIsNullable();
}
}
internal StructField(IStructFieldProxy strucFieldProxy)
{
structFieldProxy = strucFieldProxy;
}
public static StructField CreateStructField(string name, string dataType, bool isNullable)
{
return new StructField(SparkCLREnvironment.SparkCLRProxy.CreateStructField(name, dataType, isNullable));
}
}
public class DataType
{
private readonly IStructDataTypeProxy structDataTypeProxy;
internal IStructDataTypeProxy StructDataTypeProxy
{
get
{
return structDataTypeProxy;
}
}
internal DataType(IStructDataTypeProxy structDataTypeProxy)
{
this.structDataTypeProxy = structDataTypeProxy;
}
public override string ToString()
{
return structDataTypeProxy.GetDataTypeString();
}
public string SimpleString()
{
return structDataTypeProxy.GetDataTypeSimpleString();
}
}
}

View file

@@ -0,0 +1,360 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using Microsoft.Spark.CSharp.Interop.Ipc;
using Microsoft.Spark.CSharp.Proxy;
using Microsoft.Spark.CSharp.Proxy.Ipc;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace Microsoft.Spark.CSharp.Sql
{
[Serializable]
public abstract class DataType
{
/// <summary>
/// Trim "Type" in the end from class name, ToLower() to align with Scala.
/// </summary>
public string TypeName
{
get { return NormalizeTypeName(GetType().Name); }
}
/// <summary>
/// Returns TypeName by default; subclasses can override it.
/// </summary>
public virtual string SimpleString
{
get { return TypeName; }
}
/// <summary>
/// Returns just the type name (TypeName) by default; subclasses can override it.
/// </summary>
internal virtual object JsonValue { get { return TypeName; } }
public string Json
{
get
{
var jObject = JsonValue is JObject ? ((JObject)JsonValue).SortProperties() : JsonValue;
return JsonConvert.SerializeObject(jObject, Formatting.None);
}
}
public static DataType ParseDataTypeFromJson(string json)
{
return ParseDataTypeFromJson(JToken.Parse(json));
}
protected static DataType ParseDataTypeFromJson(JToken json)
{
if (json.Type == JTokenType.Object) // {name: address, type: {type: struct,...},...}
{
JToken type;
var typeJObject = (JObject)json;
if (typeJObject.TryGetValue("type", out type))
{
Type complexType;
if ((complexType = ComplexTypes.FirstOrDefault(ct => NormalizeTypeName(ct.Name) == type.ToString())) != default(Type))
{
return ((ComplexType)Activator.CreateInstance(complexType, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance
, null, new object[] { typeJObject }, null)); // create new instance of ComplexType
}
if (type.ToString() == "udt")
{
// TODO
throw new NotImplementedException();
}
}
throw new ArgumentException(string.Format("Could not parse data type: {0}", type));
}
else // {name: age, type: bigint,...} // TODO: validate JTokenTypes other than Object
{
return ParseAtomicType(json);
}
throw new ArgumentException(string.Format("Could not parse data type: {0}", json));
}
private static AtomicType ParseAtomicType(JToken type)
{
Type atomicType;
if ((atomicType = AtomicTypes.FirstOrDefault(at => NormalizeTypeName(at.Name) == type.ToString())) != default(Type))
{
return (AtomicType)Activator.CreateInstance(atomicType); // create new instance of AtomicType
}
Match fixedDecimal = DecimalType.FixedDecimal.Match(type.ToString());
if (fixedDecimal.Success)
{
return new DecimalType(int.Parse(fixedDecimal.Groups[1].Value), int.Parse(fixedDecimal.Groups[2].Value));
}
throw new ArgumentException(string.Format("Could not parse data type: {0}", type));
}
[NonSerialized]
private static readonly Type[] AtomicTypes = typeof(AtomicType).Assembly.GetTypes().Where(type =>
type.IsSubclassOf(typeof(AtomicType))).ToArray();
[NonSerialized]
private static readonly Type[] ComplexTypes = typeof(ComplexType).Assembly.GetTypes().Where(type =>
type.IsSubclassOf(typeof(ComplexType))).ToArray();
[NonSerialized]
private static readonly Func<string, string> NormalizeTypeName = s => s.Substring(0, s.Length - 4).ToLower(); // trim "Type" at the end of type name
}
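ParseDataTypeFromJson accepts both the atomic form and the object form. A quick sketch (the JSON literals mirror Spark's schema format):

var atomic = DataType.ParseDataTypeFromJson(@"""integer""");  // yields an IntegerType
var strct = (StructType)DataType.ParseDataTypeFromJson(
    @"{""type"":""struct"",""fields"":[{""name"":""age"",""type"":""long"",""nullable"":true,""metadata"":{}}]}");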
[Serializable]
public class AtomicType : DataType
{
}
[Serializable]
public abstract class ComplexType : DataType
{
public abstract DataType FromJson(JObject json);
public DataType FromJson(string json)
{
return FromJson(JObject.Parse(json));
}
}
[Serializable]
public class NullType : AtomicType { }
[Serializable]
public class StringType : AtomicType { }
[Serializable]
public class BinaryType : AtomicType { }
[Serializable]
public class BooleanType : AtomicType { }
[Serializable]
public class DateType : AtomicType { }
[Serializable]
public class TimestampType : AtomicType { }
[Serializable]
public class DoubleType : AtomicType { }
[Serializable]
public class FloatType : AtomicType { }
[Serializable]
public class ByteType : AtomicType { }
[Serializable]
public class IntegerType : AtomicType { }
[Serializable]
public class LongType : AtomicType { }
[Serializable]
public class ShortType : AtomicType { }
[Serializable]
public class DecimalType : AtomicType
{
public static Regex FixedDecimal = new Regex(@"decimal\((\d+),\s(\d+)\)");
private int? precision, scale;
public DecimalType(int? precision = null, int? scale = null)
{
this.precision = precision;
this.scale = scale;
}
internal override object JsonValue
{
get { throw new NotImplementedException(); }
}
public DataType FromJson(JObject json)
{
throw new NotImplementedException();
}
}
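ParseAtomicType above relies on FixedDecimal for fixed-precision decimals; note the pattern requires whitespace after the comma. A sketch:

var m = DecimalType.FixedDecimal.Match("decimal(10, 2)");
// m.Success == true; precision is m.Groups[1].Value ("10"), scale is m.Groups[2].Value ("2")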
[Serializable]
public class ArrayType : ComplexType
{
public DataType ElementType { get { return elementType; } }
public bool ContainsNull { get { return containsNull; } }
public ArrayType(DataType elementType, bool containsNull = true)
{
this.elementType = elementType;
this.containsNull = containsNull;
}
internal ArrayType(JObject json)
{
FromJson(json);
}
public override string SimpleString
{
get { return string.Format("array<{0}>", elementType.SimpleString); }
}
internal override object JsonValue
{
get
{
return new JObject(
new JProperty("type", TypeName),
new JProperty("elementType", elementType.JsonValue),
new JProperty("containsNull", containsNull));
}
}
public override sealed DataType FromJson(JObject json)
{
elementType = ParseDataTypeFromJson(json["elementType"]);
containsNull = (bool)json["containsNull"];
return this;
}
private DataType elementType;
private bool containsNull;
}
[Serializable]
public class MapType : ComplexType
{
internal override object JsonValue
{
get { throw new NotImplementedException(); }
}
public override DataType FromJson(JObject json)
{
throw new NotImplementedException();
}
}
[Serializable]
public class StructField : ComplexType
{
public string Name { get { return name; } }
public DataType DataType { get { return dataType; } }
public bool IsNullable { get { return isNullable; } }
public JObject Metadata { get { return metadata; } }
public StructField(string name, DataType dataType, bool isNullable = true, JObject metadata = null)
{
this.name = name;
this.dataType = dataType;
this.isNullable = isNullable;
this.metadata = metadata ?? new JObject();
}
internal StructField(JObject json)
{
FromJson(json);
}
public override string SimpleString { get { return string.Format(@"{0}:{1}", name, dataType.SimpleString); } }
internal override object JsonValue
{
get
{
return new JObject(
new JProperty("name", name),
new JProperty("type", dataType.JsonValue),
new JProperty("nullable", isNullable),
new JProperty("metadata", metadata));
}
}
public override sealed DataType FromJson(JObject json)
{
name = json["name"].ToString();
dataType = ParseDataTypeFromJson(json["type"]);
isNullable = (bool)json["nullable"];
metadata = (JObject)json["metadata"];
return this;
}
private string name;
private DataType dataType;
private bool isNullable;
[NonSerialized]
private JObject metadata;
}
[Serializable]
public class StructType : ComplexType
{
public List<StructField> Fields { get { return fields; } }
internal IStructTypeProxy StructTypeProxy
{
get
{
return structTypeProxy ??
new StructTypeIpcProxy(
new JvmObjectReference(SparkCLRIpcProxy.JvmBridge.CallStaticJavaMethod("org.apache.spark.sql.api.csharp.SQLUtils", "createSchema",
new object[] { Json }).ToString()));
}
}
public StructType(IEnumerable<StructField> fields)
{
this.fields = fields.ToList();
}
internal StructType(JObject json)
{
FromJson(json);
}
internal StructType(IStructTypeProxy structTypeProxy)
{
this.structTypeProxy = structTypeProxy;
var jsonSchema = (structTypeProxy as StructTypeIpcProxy).ToJson();
FromJson(jsonSchema);
}
public override string SimpleString
{
get { return string.Format(@"struct<{0}>", string.Join(",", fields.Select(f => f.SimpleString))); }
}
internal override object JsonValue
{
get
{
return new JObject(
new JProperty("type", TypeName),
new JProperty("fields", fields.Select(f => f.JsonValue).ToArray()));
}
}
public override sealed DataType FromJson(JObject json)
{
var fieldsJObjects = json["fields"].Select(f => (JObject)f);
fields = fieldsJObjects.Select(fieldJObject => (new StructField(fieldJObject))).ToList();
return this;
}
[NonSerialized]
private readonly IStructTypeProxy structTypeProxy;
private List<StructField> fields;
}
}
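A round-trip sketch tying the new types together (the output in comments is what the code above should produce):

var schema = new StructType(new[]
{
    new StructField("id", new StringType(), false),
    new StructField("scores", new ArrayType(new DoubleType()))
});
Console.WriteLine(schema.SimpleString);  // struct<id:string,scores:array<double>>
var parsed = DataType.ParseDataTypeFromJson(schema.Json) as StructType;
// parsed.Json equals schema.Json, since JsonValue properties are key-sorted via JsonSerDe before serialization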

View file

@@ -40,6 +40,9 @@
<HintPath>..\packages\Moq.4.2.1510.2205\lib\net40\Moq.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="Newtonsoft.Json">
<HintPath>..\packages\Newtonsoft.Json.7.0.1\lib\net45\Newtonsoft.Json.dll</HintPath>
</Reference>
<Reference Include="nunit.framework, Version=3.0.5813.39031, Culture=neutral, PublicKeyToken=2638cd05610744eb, processorArchitecture=MSIL">
<HintPath>..\packages\NUnit.3.0.1\lib\net45\nunit.framework.dll</HintPath>
<Private>True</Private>
@@ -127,4 +130,4 @@
<Target Name="AfterBuild">
</Target>
-->
</Project>
</Project>

View file

@@ -422,7 +422,7 @@ namespace AdapterTest
"123",
"Bill"
},
RowSchema.ParseRowSchemaFromJson(jsonSchema))
DataType.ParseDataTypeFromJson(jsonSchema) as StructType)
};
mockDataFrameProxy.Setup(m => m.JavaToCSharp()).Returns(new MockRddProxy(rows));
@@ -435,14 +435,6 @@
Assert.IsNotNull(rdd);
mockDataFrameProxy.Verify(m => m.JavaToCSharp(), Times.Once);
mockDataFrameProxy.Reset();
mockStructTypeProxy.Reset();
rdd = dataFrame.Rdd;
Assert.IsNotNull(rdd);
mockDataFrameProxy.Verify(m => m.JavaToCSharp(), Times.Never);
mockStructTypeProxy.Verify(m => m.ToJson(), Times.Never);
}
[Test]

View file

@@ -14,7 +14,7 @@ namespace AdapterTest.Mocks
throw new NotImplementedException();
}
public override RowSchema GetSchema()
public override StructType GetSchema()
{
throw new NotImplementedException();
}

View file

@@ -44,17 +44,6 @@ namespace AdapterTest.Mocks
return new MockSparkContextProxy(conf);
}
public IStructFieldProxy CreateStructField(string name, string dataType, bool isNullable)
{
throw new NotImplementedException();
}
public IStructTypeProxy CreateStructType(List<StructField> fields)
{
throw new NotImplementedException();
}
public ISparkContextProxy SparkContextProxy
{
get { throw new NotImplementedException(); }

View file

@@ -27,6 +27,11 @@ namespace AdapterTest.Mocks
mockSparkContextProxy = scProxy;
}
public IDataFrameProxy CreateDataFrame(IRDDProxy rddProxy, IStructTypeProxy structTypeProxy)
{
throw new NotImplementedException();
}
public IDataFrameProxy ReadDataFrame(string path, StructType schema, System.Collections.Generic.Dictionary<string, string> options)
{
throw new NotImplementedException();

View file

@@ -1,6 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="Moq" version="4.2.1510.2205" targetFramework="net45" />
<package id="Newtonsoft.Json" version="7.0.1" targetFramework="net45" />
<package id="NUnit" version="3.0.1" targetFramework="net45" />
<package id="Razorvine.Pyrolite" version="4.10.0.0" targetFramework="net45" />
<package id="Razorvine.Serpent" version="1.12.0.0" targetFramework="net45" />

View file

@@ -6,6 +6,7 @@ using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.Spark.CSharp.Core;
using Microsoft.Spark.CSharp.Sql;
using NUnit.Framework;
@@ -27,6 +28,116 @@ namespace Microsoft.Spark.CSharp.Samples
return sqlContext ?? (sqlContext = new SqlContext(SparkCLRSamples.SparkContext));
}
/// <summary>
/// Sample to create a DataFrame. The RDD is generated from SparkContext.Parallelize; the schema is created by constructing the type objects directly.
/// </summary>
[Sample]
internal static void DFCreateDataFrameSample()
{
var schemaPeople = new StructType(new List<StructField>
{
new StructField("id", new StringType()),
new StructField("name", new StringType()),
new StructField("age", new IntegerType()),
new StructField("address", new StructType(new List<StructField>
{
new StructField("city", new StringType()),
new StructField("state", new StringType())
})),
new StructField("phone numbers", new ArrayType(new StringType()))
});
var rddPeople = SparkCLRSamples.SparkContext.Parallelize(
new List<object[]>
{
new object[] { "123", "Bill", 43, new object[]{ "Columbus", "Ohio" }, new string[]{ "Tel1", "Tel2" } },
new object[] { "456", "Steve", 34, new object[]{ "Seattle", "Washington" }, new string[]{ "Tel3", "Tel4" } }
});
var dataFramePeople = GetSqlContext().CreateDataFrame(rddPeople, schemaPeople);
Console.WriteLine("------ Schema of People Data Frame:\r\n");
dataFramePeople.ShowSchema();
Console.WriteLine();
var collected = dataFramePeople.Collect().ToArray();
foreach (var people in collected)
{
string id = people.Get("id");
string name = people.Get("name");
int age = people.Get("age");
Row address = people.Get("address");
string city = address.Get("city");
string state = address.Get("state");
object[] phoneNumbers = people.Get("phone numbers");
Console.WriteLine("id:{0}, name:{1}, age:{2}, address:(city:{3},state:{4}), phoneNumbers:[{5},{6}]\r\n", id, name, age, city, state, phoneNumbers[0], phoneNumbers[1]);
}
if (SparkCLRSamples.Configuration.IsValidationEnabled)
{
Assert.AreEqual(2, dataFramePeople.Rdd.Count());
Assert.AreEqual(schemaPeople.Json, dataFramePeople.Schema.Json);
}
}
/// <summary>
/// Sample to create a DataFrame. The RDD is generated from SparkContext.TextFile; the schema is parsed from JSON.
/// </summary>
[Sample]
internal static void DFCreateDataFrameSample2()
{
var rddRequestsLog = SparkCLRSamples.SparkContext.TextFile(SparkCLRSamples.Configuration.GetInputDataPath(RequestsLog), 1).Map(r => r.Split(',').Select(s => (object)s).ToArray());
const string schemaRequestsLogJson = @"{
""fields"": [{
""metadata"": {},
""name"": ""guid"",
""nullable"": false,
""type"": ""string""
},
{
""metadata"": {},
""name"": ""datacenter"",
""nullable"": false,
""type"": ""string""
},
{
""metadata"": {},
""name"": ""abtestid"",
""nullable"": false,
""type"": ""string""
},
{
""metadata"": {},
""name"": ""traffictype"",
""nullable"": false,
""type"": ""string""
}],
""type"": ""struct""
}";
// create schema by parsing JSON
StructType requestsLogSchema = DataType.ParseDataTypeFromJson(schemaRequestsLogJson) as StructType;
var dataFrameRequestsLog = GetSqlContext().CreateDataFrame(rddRequestsLog, requestsLogSchema);
Console.WriteLine("------ Schema of RequestsLog Data Frame:");
dataFrameRequestsLog.ShowSchema();
Console.WriteLine();
var collected = dataFrameRequestsLog.Collect().ToArray();
foreach (var request in collected)
{
string guid = request.Get("guid");
string datacenter = request.Get("datacenter");
string abtestid = request.Get("abtestid");
string traffictype = request.Get("traffictype");
Console.WriteLine("guid:{0}, datacenter:{1}, abtestid:{2}, traffictype:{3}\r\n", guid, datacenter, abtestid, traffictype);
}
if (SparkCLRSamples.Configuration.IsValidationEnabled)
{
Assert.AreEqual(10, collected.Length);
Assert.AreEqual(Regex.Replace(schemaRequestsLogJson, @"\s", string.Empty), Regex.Replace(dataFrameRequestsLog.Schema.Json, @"\s", string.Empty));
}
}
/// <summary>
/// Sample to show schema of DataFrame
/// </summary>
@@ -42,10 +153,10 @@ namespace Microsoft.Spark.CSharp.Samples
/// Sample to get schema of DataFrame in json format
/// </summary>
[Sample]
internal static void DFGetSchemaToJsonSample()
internal static void DFGetSchemaJsonSample()
{
var peopleDataFrame = GetSqlContext().JsonFile(SparkCLRSamples.Configuration.GetInputDataPath(PeopleJson));
string json = peopleDataFrame.Schema.ToJson();
string json = peopleDataFrame.Schema.Json;
Console.WriteLine("schema in json format: {0}", json);
}
@@ -56,7 +167,7 @@ namespace Microsoft.Spark.CSharp.Samples
internal static void DFCollectSample()
{
var peopleDataFrame = GetSqlContext().JsonFile(SparkCLRSamples.Configuration.GetInputDataPath(PeopleJson));
IEnumerable<Row> rows = peopleDataFrame.Collect();
var rows = peopleDataFrame.Collect().ToArray();
Console.WriteLine("peopleDataFrame:");
int i = 0;
@@ -122,15 +233,13 @@ namespace Microsoft.Spark.CSharp.Samples
[Sample]
internal static void DFTextFileLoadDataFrameSample()
{
var requestsSchema = StructType.CreateStructType(
new List<StructField>
{
StructField.CreateStructField("guid", "string", false),
StructField.CreateStructField("datacenter", "string", false),
StructField.CreateStructField("abtestid", "string", false),
StructField.CreateStructField("traffictype", "string", false),
}
);
var requestsSchema = new StructType(new List<StructField>
{
new StructField("guid", new StringType(), false),
new StructField("datacenter", new StringType(), false),
new StructField("abtestid", new StringType(), false),
new StructField("traffictype", new StringType(), false),
});
var requestsDateFrame = GetSqlContext().TextFile(SparkCLRSamples.Configuration.GetInputDataPath(RequestsLog), requestsSchema);
requestsDateFrame.RegisterTempTable("requests");
@@ -154,18 +263,17 @@ namespace Microsoft.Spark.CSharp.Samples
private static DataFrame GetMetricsDataFrame()
{
var metricsSchema = StructType.CreateStructType(
new List<StructField>
var metricsSchema = new StructType(
new[]
{
StructField.CreateStructField("unknown", "string", false),
StructField.CreateStructField("date", "string", false),
StructField.CreateStructField("time", "string", false),
StructField.CreateStructField("guid", "string", false),
StructField.CreateStructField("lang", "string", false),
StructField.CreateStructField("country", "string", false),
StructField.CreateStructField("latency", "integer", false)
}
);
new StructField("unknown", new StringType(), false),
new StructField("date", new StringType(), false),
new StructField("time", new StringType(), false),
new StructField("guid", new StringType(), false),
new StructField("lang", new StringType(), false),
new StructField("country", new StringType(), false),
new StructField("latency", new StringType(), false),
});
return
GetSqlContext()
@@ -236,8 +344,8 @@ namespace Microsoft.Spark.CSharp.Samples
{
var name = peopleDataFrameSchemaField.Name;
var dataType = peopleDataFrameSchemaField.DataType;
var stringVal = dataType.ToString();
var simpleStringVal = dataType.SimpleString();
var stringVal = dataType.TypeName;
var simpleStringVal = dataType.SimpleString;
var isNullable = peopleDataFrameSchemaField.IsNullable;
Console.WriteLine("Name={0}, DT.string={1}, DT.simplestring={2}, DT.isNullable={3}", name, stringVal, simpleStringVal, isNullable);
}
@@ -388,7 +496,7 @@ namespace Microsoft.Spark.CSharp.Samples
var singleValueReplaced = peopleDataFrame.Replace("Bill", "Bill.G");
singleValueReplaced.Show();
var multiValueReplaced = peopleDataFrame.ReplaceAll(new List<int> { 14, 34 }, new List<int> { 44, 54 });
multiValueReplaced.Show();
@@ -853,7 +961,7 @@ namespace Microsoft.Spark.CSharp.Samples
Console.WriteLine("peopleDataFrame:");
var count = 0;
RowSchema schema = null;
StructType schema = null;
Row firstRow = null;
foreach (var row in rows)
{
@@ -939,43 +1047,43 @@ namespace Microsoft.Spark.CSharp.Samples
/// Verify the schema of the people DataFrame.
/// </summary>
/// <param name="schema"> schema (StructType) of the people DataFrame </param>
internal static void VerifySchemaOfPeopleDataFrame(RowSchema schema)
internal static void VerifySchemaOfPeopleDataFrame(StructType schema)
{
Assert.IsNotNull(schema);
Assert.AreEqual("struct", schema.type);
Assert.IsNotNull(schema.columns);
Assert.AreEqual(4, schema.columns.Count);
Assert.AreEqual("struct", schema.TypeName);
Assert.IsNotNull(schema.Fields);
Assert.AreEqual(4, schema.Fields.Count);
// name
var nameColSchema = schema.columns.Find(c => c.name.Equals("name"));
var nameColSchema = schema.Fields.Find(c => c.Name.Equals("name"));
Assert.IsNotNull(nameColSchema);
Assert.AreEqual("name", nameColSchema.name);
Assert.IsTrue(nameColSchema.nullable);
Assert.AreEqual("string", nameColSchema.type.ToString());
Assert.AreEqual("name", nameColSchema.Name);
Assert.IsTrue(nameColSchema.IsNullable);
Assert.AreEqual("string", nameColSchema.DataType.TypeName);
// id
var idColSchema = schema.columns.Find(c => c.name.Equals("id"));
var idColSchema = schema.Fields.Find(c => c.Name.Equals("id"));
Assert.IsNotNull(idColSchema);
Assert.AreEqual("id", idColSchema.name);
Assert.IsTrue(idColSchema.nullable);
Assert.AreEqual("string", nameColSchema.type.ToString());
Assert.AreEqual("id", idColSchema.Name);
Assert.IsTrue(idColSchema.IsNullable);
Assert.AreEqual("string", nameColSchema.DataType.TypeName);
// age
var ageColSchema = schema.columns.Find(c => c.name.Equals("age"));
var ageColSchema = schema.Fields.Find(c => c.Name.Equals("age"));
Assert.IsNotNull(ageColSchema);
Assert.AreEqual("age", ageColSchema.name);
Assert.IsTrue(ageColSchema.nullable);
Assert.AreEqual("long", ageColSchema.type.ToString());
Assert.AreEqual("age", ageColSchema.Name);
Assert.IsTrue(ageColSchema.IsNullable);
Assert.AreEqual("long", ageColSchema.DataType.TypeName);
// address
var addressColSchema = schema.columns.Find(c => c.name.Equals("address"));
var addressColSchema = schema.Fields.Find(c => c.Name.Equals("address"));
Assert.IsNotNull(addressColSchema);
Assert.AreEqual("address", addressColSchema.name);
Assert.IsTrue(addressColSchema.nullable);
Assert.IsNotNull(addressColSchema.type);
Assert.AreEqual("struct", addressColSchema.type.type);
Assert.IsNotNull(addressColSchema.type.columns.Find(c => c.name.Equals("state")));
Assert.IsNotNull(addressColSchema.type.columns.Find(c => c.name.Equals("city")));
Assert.AreEqual("address", addressColSchema.Name);
Assert.IsTrue(addressColSchema.IsNullable);
Assert.IsNotNull(addressColSchema.DataType);
Assert.AreEqual("struct", addressColSchema.DataType.TypeName);
Assert.IsNotNull(((StructType)addressColSchema.DataType).Fields.Find(c => c.Name.Equals("state")));
Assert.IsNotNull(((StructType)addressColSchema.DataType).Fields.Find(c => c.Name.Equals("city")));
}
/// <summary>
@@ -1128,25 +1236,25 @@
if (x == null && y == null) return true;
if (x == null && y != null || x != null && y == null) return false;
foreach (var col in x.GetSchema().columns)
foreach (var col in x.GetSchema().Fields)
{
if (!y.GetSchema().columns.Any(c => c.ToString() == col.ToString())) return false;
if (!y.GetSchema().Fields.Any(c => c.Name == col.Name)) return false;
if (col.type.columns.Any())
if (col.DataType is StructType)
{
if (!IsRowEqual(x.GetAs<Row>(col.name), y.GetAs<Row>(col.name), columnsComparer)) return false;
if (!IsRowEqual(x.GetAs<Row>(col.Name), y.GetAs<Row>(col.Name), columnsComparer)) return false;
}
else if (x.Get(col.name) == null && y.Get(col.name) != null || x.Get(col.name) != null && y.Get(col.name) == null) return false;
else if (x.Get(col.name) != null && y.Get(col.name) != null)
else if (x.Get(col.Name) == null && y.Get(col.Name) != null || x.Get(col.Name) != null && y.Get(col.Name) == null) return false;
else if (x.Get(col.Name) != null && y.Get(col.Name) != null)
{
Func<object, object, bool> colComparer;
if (columnsComparer != null && columnsComparer.TryGetValue(col.name, out colComparer))
if (columnsComparer != null && columnsComparer.TryGetValue(col.Name, out colComparer))
{
if (!colComparer(x.Get(col.name), y.Get(col.name))) return false;
if (!colComparer(x.Get(col.Name), y.Get(col.Name))) return false;
}
else
{
if (x.Get(col.name).ToString() != y.Get(col.name).ToString()) return false;
if (x.Get(col.Name).ToString() != y.Get(col.Name).ToString()) return false;
}
}
}

View file

@@ -33,6 +33,10 @@
<WarningLevel>4</WarningLevel>
</PropertyGroup>
<ItemGroup>
<Reference Include="Newtonsoft.Json, Version=7.0.0.0, Culture=neutral, PublicKeyToken=30ad4fe6b2a6aeed, processorArchitecture=MSIL">
<HintPath>..\..\packages\Newtonsoft.Json.7.0.1\lib\net45\Newtonsoft.Json.dll</HintPath>
<Private>True</Private>
</Reference>
<Reference Include="nunit.framework, Version=3.0.5813.39031, Culture=neutral, PublicKeyToken=2638cd05610744eb, processorArchitecture=MSIL">
<HintPath>..\..\packages\NUnit.3.0.1\lib\net45\nunit.framework.dll</HintPath>
<Private>True</Private>

View file

@@ -1,4 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="Newtonsoft.Json" version="7.0.1" targetFramework="net45" />
<package id="NUnit" version="3.0.1" targetFramework="net45" />
</packages>

View file

@@ -3,14 +3,16 @@
package org.apache.spark.sql.api.csharp
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.io.{ByteArrayOutputStream, DataOutputStream}
import org.apache.spark.SparkContext
import org.apache.spark.api.csharp.SerDe
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{DataType, FloatType, StructField, StructType}
import org.apache.spark.sql.types.{DataType, FloatType, StructType}
import org.apache.spark.sql._
import java.util.{ArrayList => JArrayList}
/**
* Utility functions for DataFrame in SparkCLR
@@ -31,10 +33,6 @@ object SQLUtils {
arr.toSeq
}
def createStructType(fields : Seq[StructField]): StructType = {
StructType(fields)
}
def getSQLDataType(dataType: String): DataType = {
dataType match {
case "byte" => org.apache.spark.sql.types.ByteType
@@ -54,17 +52,6 @@ object SQLUtils {
}
}
def createStructField(name: String, dataType: String, nullable: Boolean): StructField = {
val dtObj = getSQLDataType(dataType)
StructField(name, dtObj, nullable)
}
def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
val num = schema.fields.size
val rowRDD = rdd.map(bytesToRow(_, schema))
sqlContext.createDataFrame(rowRDD, schema)
}
def dfToRowRDD(df: DataFrame): RDD[Array[Byte]] = {
df.map(r => rowToCSharpBytes(r))
}
@@ -77,14 +64,6 @@ object SQLUtils {
}
}
private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = {
val bis = new ByteArrayInputStream(bytes)
val dis = new DataInputStream(bis)
val num = SerDe.readInt(dis)
Row.fromSeq((0 until num).map { i =>
doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
}.toSeq)
}
private[this] def rowToCSharpBytes(row: Row): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -176,8 +155,11 @@ object SQLUtils {
dfReader.load(path)
}
def loadTextFile(sqlContext: SQLContext, path: String, delimiter: String, schema: StructType) : DataFrame = {
def loadTextFile(sqlContext: SQLContext, path: String, delimiter: String, schemaJson: String) : DataFrame = {
val stringRdd = sqlContext.sparkContext.textFile(path)
val schema = createSchema(schemaJson)
val rowRdd = stringRdd.map{s =>
val columns = s.split(delimiter)
columns.length match {
@@ -217,4 +199,21 @@ object SQLUtils {
sqlContext.createDataFrame(rowRdd, schema)
}
def createSchema(schemaJson: String) : StructType = {
DataType.fromJson(schemaJson).asInstanceOf[StructType]
}
def byteArrayRDDToAnyArrayRDD(jrdd: JavaRDD[Array[Byte]]) : RDD[Array[_ >: AnyRef]] = {
// JavaRDD[Array[Byte]] -> JavaRDD[Any]
val jrddAny = SerDeUtil.pythonToJava(jrdd, true)
// JavaRDD[Any] -> RDD[Array[_]]
jrddAny.rdd.map {
case objs: JArrayList[_] =>
objs.toArray
case obj if obj.getClass.isArray =>
obj.asInstanceOf[Array[_]].toArray
}
}
}