diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Adapter.csproj b/csharp/Adapter/Microsoft.Spark.CSharp/Adapter.csproj index 6dccf1b..316e131 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Adapter.csproj +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Adapter.csproj @@ -76,6 +76,7 @@ + @@ -118,7 +119,7 @@ - + diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Core/RDD.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Core/RDD.cs index b36b620..0e8e8ba 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Core/RDD.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Core/RDD.cs @@ -283,7 +283,7 @@ namespace Microsoft.Spark.CSharp.Core /// public RDD Distinct(int numPartitions = 0) { - return Map(x => new KeyValuePair(x, 0)).ReduceByKey((x,y) => x, numPartitions).Map(x => x.Key); + return Map(x => new KeyValuePair(x, 0)).ReduceByKey((x, y) => x, numPartitions).Map(x => x.Key); } /// @@ -417,7 +417,7 @@ namespace Microsoft.Spark.CSharp.Core else { const double delta = 0.00005; - var gamma = - Math.Log(delta) / total; + var gamma = -Math.Log(delta) / total; return Math.Min(1, fraction + gamma + Math.Sqrt(gamma * gamma + 2 * gamma * fraction)); } } @@ -811,6 +811,7 @@ namespace Microsoft.Spark.CSharp.Core int left = num - items.Count; IEnumerable partitions = Enumerable.Range(partsScanned, Math.Min(numPartsToTry, totalParts - partsScanned)); + var mappedRDD = MapPartitionsWithIndex(new TakeHelper(left).Execute); int port = sparkContext.SparkContextProxy.RunJob(mappedRDD.RddProxy, partitions); @@ -867,7 +868,7 @@ namespace Microsoft.Spark.CSharp.Core { return Map>(v => new KeyValuePair(v, default(T))).SubtractByKey ( - other.Map>(v => new KeyValuePair(v, default(T))), + other.Map>(v => new KeyValuePair(v, default(T))), numPartitions ) .Keys(); @@ -1044,7 +1045,7 @@ namespace Microsoft.Spark.CSharp.Core /// public IEnumerable ToLocalIterator() { - foreach(int partition in Enumerable.Range(0, GetNumPartitions())) + foreach (int partition in Enumerable.Range(0, GetNumPartitions())) { var mappedRDD = MapPartitionsWithIndex((pid, iter) => iter); int port = sparkContext.SparkContextProxy.RunJob(mappedRDD.RddProxy, Enumerable.Range(partition, 1)); @@ -1382,7 +1383,7 @@ namespace Microsoft.Spark.CSharp.Core internal KeyValuePair Execute(T input) { - return new KeyValuePair(func(input), input); + return new KeyValuePair(func(input), input); } } [Serializable] @@ -1429,7 +1430,7 @@ namespace Microsoft.Spark.CSharp.Core else if (y.Value) return x; else - return new KeyValuePair(func(x.Key, y.Key), false); + return new KeyValuePair(func(x.Key, y.Key), false); } } [Serializable] diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Interop/Ipc/JsonSerDe.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Interop/Ipc/JsonSerDe.cs new file mode 100644 index 0000000..0ba0aea --- /dev/null +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Interop/Ipc/JsonSerDe.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Linq; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Spark.CSharp.Interop.Ipc +{ + /// + /// Json.NET Serialization/Deserialization helper class. + /// + public static class JsonSerDe + { + // Note: the Scala side uses JSortedObject when parsing JSON, so the properties in JObject need to be sorted + /// + /// Extension method to sort items in a JSON object by keys.
+ /// + /// + /// + public static JObject SortProperties(this JObject jObject) + { + JObject sortedJObject = new JObject(); + foreach (var property in jObject.Properties().OrderBy(p => p.Name)) + { + if (property.Value is JObject) + { + var propJObject = property.Value as JObject; + sortedJObject.Add(property.Name, propJObject.SortProperties()); + } + else if (property.Value is JArray) + { + var propJArray = property.Value as JArray; + sortedJObject.Add(property.Name, propJArray.SortProperties()); + } + else + { + sortedJObject.Add(property.Name, property.Value); + } + } + return sortedJObject; + } + + /// + /// Extension method to sort items in a JSON array by keys. + /// + public static JArray SortProperties(this JArray jArray) + { + JArray sortedJArray = new JArray(); + if (jArray.Count == 0) return jArray; + + foreach (var item in jArray) + { + if (item is JObject) + { + var sortedItem = ((JObject)item).SortProperties(); + sortedJArray.Add(sortedItem); + } + else if (item is JArray) + { + var sortedItem = ((JArray)item).SortProperties(); + sortedJArray.Add(sortedItem); + } + } + return sortedJArray; + } + } + +} diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/ISparkCLRProxy.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/ISparkCLRProxy.cs index 1d01d95..c3abf77 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/ISparkCLRProxy.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/ISparkCLRProxy.cs @@ -17,8 +17,6 @@ namespace Microsoft.Spark.CSharp.Proxy ISparkContextProxy SparkContextProxy { get; } ISparkConfProxy CreateSparkConf(bool loadDefaults = true); ISparkContextProxy CreateSparkContext(ISparkConfProxy conf); - IStructFieldProxy CreateStructField(string name, string dataType, bool isNullable); - IStructTypeProxy CreateStructType(List fields); IDStreamProxy CreateCSharpDStream(IDStreamProxy jdstream, byte[] func, string deserializer); IDStreamProxy CreateCSharpTransformed2DStream(IDStreamProxy jdstream, IDStreamProxy jother, byte[] func, string deserializer, string deserializerOther); IDStreamProxy CreateCSharpReducedWindowedDStream(IDStreamProxy jdstream, byte[] func, byte[] invFunc, int windowSeconds, int slideSeconds, string deserializer); diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/ISqlContextProxy.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/ISqlContextProxy.cs index e49204c..b84c86e 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/ISqlContextProxy.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/ISqlContextProxy.cs @@ -13,6 +13,7 @@ namespace Microsoft.Spark.CSharp.Proxy { internal interface ISqlContextProxy { + IDataFrameProxy CreateDataFrame(IRDDProxy rddProxy, IStructTypeProxy structTypeProxy); IDataFrameProxy ReadDataFrame(string path, StructType schema, Dictionary options); IDataFrameProxy JsonFile(string path); IDataFrameProxy TextFile(string path, StructType schema, string delimiter); diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/DataFrameIpcProxy.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/DataFrameIpcProxy.cs index d735e73..d354efa 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/DataFrameIpcProxy.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/DataFrameIpcProxy.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.Linq; using Microsoft.Spark.CSharp.Core; using Microsoft.Spark.CSharp.Interop.Ipc; -using Microsoft.Spark.CSharp.Sql; namespace Microsoft.Spark.CSharp.Proxy.Ipc { @@ -77,7 +76,7 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc return
SparkCLRIpcProxy.JvmBridge.CallNonStaticJavaMethod( jvmDataFrameReference, "showString", - new object[] { numberOfRows , truncate }).ToString(); + new object[] { numberOfRows, truncate }).ToString(); } public bool IsLocal() diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SparkCLRIpcProxy.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SparkCLRIpcProxy.cs index 39d7fb9..9b521a4 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SparkCLRIpcProxy.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SparkCLRIpcProxy.cs @@ -62,33 +62,6 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc return sparkContextProxy; } - public IStructFieldProxy CreateStructField(string name, string dataType, bool isNullable) - { - return new StructFieldIpcProxy( - new JvmObjectReference( - JvmBridge.CallStaticJavaMethod( - "org.apache.spark.sql.api.csharp.SQLUtils", "createStructField", - new object[] { name, dataType, isNullable }).ToString() - ) - ); - } - - public IStructTypeProxy CreateStructType(List fields) - { - var fieldsReference = fields.Select(s => (s.StructFieldProxy as StructFieldIpcProxy).JvmStructFieldReference).ToList().Cast(); - - var seq = - new JvmObjectReference( - JvmBridge.CallStaticJavaMethod("org.apache.spark.sql.api.csharp.SQLUtils", - "toSeq", new object[] { fieldsReference }).ToString()); - - return new StructTypeIpcProxy( - new JvmObjectReference( - JvmBridge.CallStaticJavaMethod("org.apache.spark.sql.api.csharp.SQLUtils", "createStructType", new object[] { seq }).ToString() - ) - ); - } - public IDStreamProxy CreateCSharpDStream(IDStreamProxy jdstream, byte[] func, string deserializer) { var jvmDStreamReference = SparkCLRIpcProxy.JvmBridge.CallConstructor("org.apache.spark.streaming.api.csharp.CSharpDStream", diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SparkContextIpcProxy.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SparkContextIpcProxy.cs index 8046fd6..9bcd7ef 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SparkContextIpcProxy.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SparkContextIpcProxy.cs @@ -116,7 +116,6 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc return new RDDIpcProxy(jvmRddReference); } - //TODO - this implementation is slow.
Replace with call to createRDDFromArray() in CSharpRDD public IRDDProxy Parallelize(IEnumerable values, int numSlices) { var jvmRddReference = new JvmObjectReference((string)SparkCLRIpcProxy.JvmBridge.CallStaticJavaMethod("org.apache.spark.api.csharp.CSharpRDD", "createRDDFromArray", new object[] { jvmSparkContextReference, values, numSlices })); diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SqlContextIpcProxy.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SqlContextIpcProxy.cs index 6638143..5e4f62a 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SqlContextIpcProxy.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Proxy/Ipc/SqlContextIpcProxy.cs @@ -21,6 +21,17 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc this.jvmSqlContextReference = jvmSqlContextReference; } + public IDataFrameProxy CreateDataFrame(IRDDProxy rddProxy, IStructTypeProxy structTypeProxy) + { + var rdd = new JvmObjectReference(SparkCLRIpcProxy.JvmBridge.CallStaticJavaMethod("org.apache.spark.sql.api.csharp.SQLUtils", "byteArrayRDDToAnyArrayRDD", + new object[] { (rddProxy as RDDIpcProxy).JvmRddReference }).ToString()); + + return new DataFrameIpcProxy( + new JvmObjectReference( + SparkCLRIpcProxy.JvmBridge.CallNonStaticJavaMethod(jvmSqlContextReference, "applySchemaToPythonRDD", + new object[] { rdd, (structTypeProxy as StructTypeIpcProxy).JvmStructTypeReference }).ToString()), this); + } + public IDataFrameProxy ReadDataFrame(string path, StructType schema, Dictionary options) { //TODO parameter Dictionary options is not used right now - it is meant to be passed on to data sources @@ -44,7 +55,7 @@ namespace Microsoft.Spark.CSharp.Proxy.Ipc new JvmObjectReference( SparkCLRIpcProxy.JvmBridge.CallStaticJavaMethod( "org.apache.spark.sql.api.csharp.SQLUtils", "loadTextFile", - new object[] {jvmSqlContextReference, path, delimiter, (schema.StructTypeProxy as StructTypeIpcProxy).JvmStructTypeReference}).ToString() + new object[] { jvmSqlContextReference, path, delimiter, schema.Json}).ToString() ), this ); } diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/DataFrame.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/DataFrame.cs index bda8be0..400ab0c 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/DataFrame.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/DataFrame.cs @@ -22,7 +22,7 @@ namespace Microsoft.Spark.CSharp.Sql private readonly IDataFrameProxy dataFrameProxy; [NonSerialized] private readonly SparkContext sparkContext; - [NonSerialized] + private StructType schema; [NonSerialized] private RDD rdd; @@ -40,7 +40,7 @@ namespace Microsoft.Spark.CSharp.Sql if (rdd == null) { rddProxy = dataFrameProxy.JavaToCSharp(); - rdd = new RDD(rddProxy, sparkContext, SerializedMode.Row); + rdd = new RDD(rddProxy, sparkContext, SerializedMode.Row); } return rdd; } @@ -137,7 +137,7 @@ namespace Microsoft.Spark.CSharp.Sql /// public void ShowSchema() { - List nameTypeList = Schema.Fields.Select(structField => string.Format("{0}:{1}", structField.Name, structField.DataType.SimpleString())).ToList(); + var nameTypeList = Schema.Fields.Select(structField => structField.SimpleString); Console.WriteLine(string.Join(", ", nameTypeList)); } @@ -145,18 +145,18 @@ namespace Microsoft.Spark.CSharp.Sql /// Returns all of Rows in this DataFrame /// public IEnumerable Collect() - { + { int port = RddProxy.CollectAndServe(); return Rdd.Collect(port).Cast(); } /// - /// Converts the DataFrame to RDD of byte[] + /// Converts the DataFrame to RDD of Row /// /// resulting RDD - public RDD ToRDD() //RDD created
using byte representation of GenericRow objects + public RDD ToRDD() //RDD created using byte representation of Row objects { - return new RDD(dataFrameProxy.ToRDD(), sparkContext); + return Rdd; } /// diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Row.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Row.cs index e5b0e2f..04af6e4 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Row.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Row.cs @@ -2,13 +2,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using System.Collections; using System.Collections.Generic; using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; - using Microsoft.Spark.CSharp.Services; namespace Microsoft.Spark.CSharp.Sql @@ -31,24 +27,24 @@ namespace Microsoft.Spark.CSharp.Sql /// /// Schema for the row. /// - public abstract RowSchema GetSchema(); + public abstract StructType GetSchema(); /// /// Returns the value at position i. /// - public abstract object Get(int i); + public abstract dynamic Get(int i); /// /// Returns the value of a given columnName. /// - public abstract object Get(string columnName); + public abstract dynamic Get(string columnName); /// /// Returns the value at position i, the return value will be cast to type T. /// public T GetAs(int i) { - object o = Get(i); + dynamic o = Get(i); try { T result = (T)o; @@ -66,7 +62,7 @@ namespace Microsoft.Spark.CSharp.Sql /// public T GetAs(string columnName) { - object o = Get(columnName); + dynamic o = Get(columnName); try { T result = (T)o; @@ -80,145 +76,12 @@ namespace Microsoft.Spark.CSharp.Sql } } - /// - /// Schema of Row - /// - [Serializable] - public class RowSchema - { - public string type; - public List columns; - - private readonly Dictionary columnName2Index = new Dictionary(); - - public RowSchema(string type) - { - this.type = type; - this.columns = new List(); - } - - public RowSchema(string type, List cols) - { - int index = 0; - foreach (var col in cols) - { - string columnName = col.name; - //TODO - investigate the issue and uncomment the following checks - /* - * UDFs produce empty column name.
Commenting out the following code at the time of upgrading to 1.5.2 */ - //if (string.IsNullOrEmpty(columnName)) - //{ - // throw new Exception(string.Format("Null column name at pos: {0}", index)); - //} - - //if (columnName2Index.ContainsKey(columnName)) - //{ - // throw new Exception(string.Format("duplicate column name ({0}) in pos ({1}) and ({2})", - // columnName, columnName2Index[columnName], index)); - //} columnName2Index[columnName] = index; index++; } - - this.type = type; - this.columns = cols; - } - - internal int GetIndexByColumnName(string ColumnName) - { - if (!columnName2Index.ContainsKey(ColumnName)) - { - throw new Exception(string.Format("unknown ColumnName: {0}", ColumnName)); - } - - return columnName2Index[ColumnName]; - } - - public override string ToString() - { - string result; - - if (columns.Any()) - { - List cols = new List(); - foreach (var col in columns) - { - cols.Add(col.ToString()); - } - - result = "{" + - string.Format("type: {0}, columns: [{1}]", type, string.Join(", ", cols.ToArray())) + - "}"; - } - else - { - result = type; - } - - return result; - } - - internal static RowSchema ParseRowSchemaFromJson(string json) - { - JObject joType = JObject.Parse(json); - string type = joType["type"].ToString(); - - List columns = new List(); - List jtFields = joType["fields"].Children().ToList(); - foreach (JToken jtField in jtFields) - { - ColumnSchema col = ColumnSchema.ParseColumnSchemaFromJson(jtField.ToString()); - columns.Add(col); - } - - return new RowSchema(type, columns); - } - - } - - /// - /// Schema for column - /// - [Serializable] - public class ColumnSchema - { - public string name; - public RowSchema type; - public bool nullable; - - public override string ToString() - { - string str = string.Format("name: {0}, type: {1}, nullable: {2}", name, type, nullable); - return "{" + str + "}"; - } - - internal static ColumnSchema ParseColumnSchemaFromJson(string json) - { - ColumnSchema col = new ColumnSchema(); - JObject joField = JObject.Parse(json); - col.name = joField["name"].ToString(); - col.nullable = (bool)(joField["nullable"]); - - JToken jtType = joField["type"]; - if (jtType.Type == JTokenType.String) - { - col.type = new RowSchema(joField["type"].ToString()); - } - else - { - col.type = RowSchema.ParseRowSchemaFromJson(joField["type"].ToString()); - } - - return col; - } - } - [Serializable] internal class RowImpl : Row { - private readonly RowSchema schema; - private readonly object[] values; + private readonly StructType schema; + public dynamic[] Values { get { return values; } } + private readonly dynamic[] values; private readonly int columnCount; @@ -229,17 +92,15 @@ return Get(index); } } - - - internal RowImpl(object data, RowSchema schema) + internal RowImpl(dynamic data, StructType schema) { - if (data is object[]) + if (data is dynamic[]) { - values = data as object[]; + values = data as dynamic[]; } - else if (data is List) + else if (data is List) { - values = (data as List).ToArray(); + values = (data as List).ToArray(); } else { @@ -249,7 +110,7 @@ this.schema = schema; columnCount = values.Count(); - int schemaColumnCount = this.schema.columns.Count(); + int schemaColumnCount = this.schema.Fields.Count(); if (columnCount != schemaColumnCount) { throw new Exception(string.Format("column count inferred from data ({0}) and schema ({1}) mismatch", columnCount, schemaColumnCount)); } @@ -263,12 +124,12 @@ return
columnCount; } - public override RowSchema GetSchema() + public override StructType GetSchema() { return schema; } - public override object Get(int i) + public override dynamic Get(int i) { if (i >= columnCount) { @@ -278,9 +139,9 @@ return values[i]; } - public override object Get(string columnName) + public override dynamic Get(string columnName) { - int index = schema.GetIndexByColumnName(columnName); + int index = schema.Fields.FindIndex(f => f.Name == columnName); // case sensitive return Get(index); } @@ -305,21 +166,69 @@ private void Initialize() { + int index = 0; - foreach (var col in schema.columns) + foreach (var field in schema.Fields) { - if (col.type.columns.Any()) // this column itself is a sub-row + if (field.DataType is ArrayType) { - object value = values[index]; + Func convertArrayTypeToStructTypeFunc = (dataType, length) => + { + StructField[] fields = new StructField[length]; + for (int i = 0; i < length; i++) + { + fields[i] = new StructField(string.Format("_array_{0}", i), dataType); + } + return new StructType(fields); + }; + var elementType = (field.DataType as ArrayType).ElementType; + + // Note: When creating an object from JSON, PySpark converts a JSON array to a Python List (https://github.com/apache/spark/blob/branch-1.4/python/pyspark/sql/types.py, _create_cls(dataType)), + // then the Pyrolite unpickler converts a Python List to a C# ArrayList (https://github.com/irmen/Pyrolite/blob/v4.10/README.txt). So values[index] should be of type ArrayList; + // in case Python changes its implementation (i.e., the value is not of type ArrayList), try casting to object[] because the Pyrolite unpickler converts a Python Tuple to a C# object[]. + object[] valueOfArray = values[index] is ArrayList ? (values[index] as ArrayList).ToArray() : values[index] as object[]; + if (valueOfArray == null) + { + throw new ArgumentException("Cannot parse data of ArrayType: " + field.Name); + } + + values[index] = new RowImpl(valueOfArray, elementType as StructType ?? convertArrayTypeToStructTypeFunc(elementType, valueOfArray.Length)).values; + } + else if (field.DataType is MapType) + { + //TODO + throw new NotImplementedException(); + } + else if (field.DataType is StructType) + { + dynamic value = values[index]; if (value != null) { - RowImpl subRow = new RowImpl(values[index], col.type); + var subRow = new RowImpl(values[index], field.DataType as StructType); values[index] = subRow; } } - + else if (field.DataType is DecimalType) + { + //TODO + throw new NotImplementedException(); + } + else if (field.DataType is DateType) + { + //TODO + throw new NotImplementedException(); + } + else if (field.DataType is StringType) + { + if (values[index] != null) values[index] = values[index].ToString(); + } + else + { + values[index] = values[index]; + } index++; } } } + } diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/RowConstructor.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/RowConstructor.cs index e80ccd3..bc5feb8 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/RowConstructor.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/RowConstructor.cs @@ -1,11 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
-using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; using Razorvine.Pickle; namespace Microsoft.Spark.CSharp.Sql @@ -65,7 +60,8 @@ namespace Microsoft.Spark.CSharp.Sql /// public Row GetRow() { - var row = new RowImpl(GetValues(Values), RowSchema.ParseRowSchemaFromJson(Schema)); + var schema = DataType.ParseDataTypeFromJson(Schema) as StructType; + var row = new RowImpl(GetValues(Values), schema); //Resetting schema here so that rows from multiple DataFrames can be processed in the same AppDomain //next row will have schema - so resetting is fine diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/SqlContext.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/SqlContext.cs index c13eeda..93aa071 100644 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/SqlContext.cs +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/SqlContext.cs @@ -8,6 +8,7 @@ using System.Text; using System.Threading.Tasks; using Microsoft.Spark.CSharp.Core; using Microsoft.Spark.CSharp.Interop; +using Microsoft.Spark.CSharp.Interop.Ipc; using Microsoft.Spark.CSharp.Proxy; namespace Microsoft.Spark.CSharp.Sql @@ -39,9 +40,16 @@ return new DataFrame(sqlContextProxy.ReadDataFrame(path, schema, options), sparkContext); } - public DataFrame CreateDataFrame(RDD rdd, StructType schema) + public DataFrame CreateDataFrame(RDD rdd, StructType schema) { - throw new NotImplementedException(); + // Note: This is for pickling the RDD; the conversion to a pickled RDD happens in CSharpWorker. + // The below sqlContextProxy.CreateDataFrame() will call byteArrayRDDToAnyArrayRDD() of SQLUtils.scala, which only accepts an RDD of type RDD[Array[Byte]]. + // In byteArrayRDDToAnyArrayRDD() of SQLUtils.scala, SerDeUtil.pythonToJava() will be called, which internally is a mapPartitions. + // It will not be executed until the CSharpWorker finishes pickling to RDD[Array[Byte]]. + var rddRow = rdd.Map(r => r); + rddRow.serializedMode = SerializedMode.Row; + + return new DataFrame(sqlContextProxy.CreateDataFrame(rddRow.RddProxy, schema.StructTypeProxy), sparkContext); } /// diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Struct.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Struct.cs deleted file mode 100644 index e75e555..0000000 --- a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Struct.cs +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information.
- -using System.Collections.Generic; -using System.Linq; -using Microsoft.Spark.CSharp.Interop; -using Microsoft.Spark.CSharp.Proxy; -using Microsoft.Spark.CSharp.Proxy.Ipc; - -namespace Microsoft.Spark.CSharp.Sql -{ - /// - /// Schema of DataFrame - /// - public class StructType - { - private readonly IStructTypeProxy structTypeProxy; - - internal IStructTypeProxy StructTypeProxy - { - get - { - return structTypeProxy; - } - } - - public List Fields //TODO - avoid calling method everytime - { - get - { - var structTypeFieldJvmObjectReferenceList = - structTypeProxy.GetStructTypeFields(); - var structFieldList = new List(structTypeFieldJvmObjectReferenceList.Count); - structFieldList.AddRange( - structTypeFieldJvmObjectReferenceList.Select( - structTypeFieldJvmObjectReference => new StructField(structTypeFieldJvmObjectReference))); - return structFieldList; - } - } - - public string ToJson() - { - return structTypeProxy.ToJson(); - } - - - internal StructType(IStructTypeProxy structTypeProxy) - { - this.structTypeProxy = structTypeProxy; - } - - public static StructType CreateStructType(List structFields) - { - return new StructType(SparkCLREnvironment.SparkCLRProxy.CreateStructType(structFields)); - } - } - - /// - /// Schema for DataFrame column - /// - public class StructField - { - private readonly IStructFieldProxy structFieldProxy; - - internal IStructFieldProxy StructFieldProxy - { - get - { - return structFieldProxy; - } - } - - public string Name - { - get - { - return structFieldProxy.GetStructFieldName(); - } - } - - public DataType DataType - { - get - { - return new DataType(structFieldProxy.GetStructFieldDataType()); - } - } - - public bool IsNullable - { - get - { - return structFieldProxy.GetStructFieldIsNullable(); - } - } - - internal StructField(IStructFieldProxy strucFieldProxy) - { - structFieldProxy = strucFieldProxy; - } - - public static StructField CreateStructField(string name, string dataType, bool isNullable) - { - return new StructField(SparkCLREnvironment.SparkCLRProxy.CreateStructField(name, dataType, isNullable)); - } - } - - public class DataType - { - private readonly IStructDataTypeProxy structDataTypeProxy; - - internal IStructDataTypeProxy StructDataTypeProxy - { - get - { - return structDataTypeProxy; - } - } - - internal DataType(IStructDataTypeProxy structDataTypeProxy) - { - this.structDataTypeProxy = structDataTypeProxy; - } - - public override string ToString() - { - return structDataTypeProxy.GetDataTypeString(); - } - - public string SimpleString() - { - return structDataTypeProxy.GetDataTypeSimpleString(); - } - } -} \ No newline at end of file diff --git a/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Types.cs b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Types.cs new file mode 100644 index 0000000..1a34892 --- /dev/null +++ b/csharp/Adapter/Microsoft.Spark.CSharp/Sql/Types.cs @@ -0,0 +1,360 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text.RegularExpressions; +using Microsoft.Spark.CSharp.Interop.Ipc; +using Microsoft.Spark.CSharp.Proxy; +using Microsoft.Spark.CSharp.Proxy.Ipc; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +
+namespace Microsoft.Spark.CSharp.Sql +{ + [Serializable] + public abstract class DataType + { + /// + /// Trim "Type" from the end of the class name and ToLower() it, to align with Scala.
+ /// + public string TypeName + { + get { return NormalizeTypeName(GetType().Name); } + } + + /// + /// return TypeName by default; subclasses can override it + /// + public virtual string SimpleString + { + get { return TypeName; } + } + + /// + /// return only the type as TypeName by default; subclasses can override it + /// + internal virtual object JsonValue { get { return TypeName; } } + + public string Json + { + get + { + var jObject = JsonValue is JObject ? ((JObject)JsonValue).SortProperties() : JsonValue; + return JsonConvert.SerializeObject(jObject, Formatting.None); + } + } + + public static DataType ParseDataTypeFromJson(string json) + { + return ParseDataTypeFromJson(JToken.Parse(json)); + } + + protected static DataType ParseDataTypeFromJson(JToken json) + { + if (json.Type == JTokenType.Object) // {name: address, type: {type: struct,...},...} + { + JToken type; + var typeJObject = (JObject)json; + if (typeJObject.TryGetValue("type", out type)) + { + Type complexType; + if ((complexType = ComplexTypes.FirstOrDefault(ct => NormalizeTypeName(ct.Name) == type.ToString())) != default(Type)) + { + return ((ComplexType)Activator.CreateInstance(complexType, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance + , null, new object[] { typeJObject }, null)); // create new instance of ComplexType + } + if (type.ToString() == "udt") + { + // TODO + throw new NotImplementedException(); + } + } + throw new ArgumentException(string.Format("Could not parse data type: {0}", type)); + } + else // {name: age, type: bigint,...} // TODO: validate more JTokenType other than Object + { + return ParseAtomicType(json); + } + + throw new ArgumentException(string.Format("Could not parse data type: {0}", json)); + } + + private static AtomicType ParseAtomicType(JToken type) + { + Type atomicType; + if ((atomicType = AtomicTypes.FirstOrDefault(at => NormalizeTypeName(at.Name) == type.ToString())) != default(Type)) + { + return (AtomicType)Activator.CreateInstance(atomicType); // create new instance of AtomicType + } + + Match fixedDecimal = DecimalType.FixedDecimal.Match(type.ToString()); + if (fixedDecimal.Success) + { + return new DecimalType(int.Parse(fixedDecimal.Groups[1].Value), int.Parse(fixedDecimal.Groups[2].Value)); + } + + throw new ArgumentException(string.Format("Could not parse data type: {0}", type)); + } + + [NonSerialized] + private static readonly Type[] AtomicTypes = typeof(AtomicType).Assembly.GetTypes().Where(type => + type.IsSubclassOf(typeof(AtomicType))).ToArray(); + + [NonSerialized] + private static readonly Type[] ComplexTypes = typeof(ComplexType).Assembly.GetTypes().Where(type => + type.IsSubclassOf(typeof(ComplexType))).ToArray(); + + [NonSerialized] + private static readonly Func NormalizeTypeName = s => s.Substring(0, s.Length - 4).ToLower(); // trim "Type" at the end of type name + + + } + + [Serializable] + public class AtomicType : DataType + { + } + + [Serializable] + public abstract class ComplexType : DataType + { + public abstract DataType FromJson(JObject json); + public DataType FromJson(string json) + { + return FromJson(JObject.Parse(json)); + } + } + + + [Serializable] + public class NullType : AtomicType { } + + [Serializable] + public class StringType : AtomicType { } + + [Serializable] + public class BinaryType : AtomicType { } + + [Serializable] + public class BooleanType : AtomicType { } + + [Serializable] + public class DateType : AtomicType { } + + [Serializable] + public class TimestampType : AtomicType { } + + [Serializable] + public class
DoubleType : AtomicType { } + + [Serializable] + public class FloatType : AtomicType { } + + [Serializable] + public class ByteType : AtomicType { } + + [Serializable] + public class IntegerType : AtomicType { } + + [Serializable] + public class LongType : AtomicType { } + + [Serializable] + public class ShortType : AtomicType { } + + [Serializable] + public class DecimalType : AtomicType + { + public static Regex FixedDecimal = new Regex(@"decimal\((\d+),\s(\d+)\)"); + private int? precision, scale; + public DecimalType(int? precision = null, int? scale = null) + { + this.precision = precision; + this.scale = scale; + } + + internal override object JsonValue + { + get { throw new NotImplementedException(); } + } + + public DataType FromJson(JObject json) + { + throw new NotImplementedException(); + } + } + + [Serializable] + public class ArrayType : ComplexType + { + public DataType ElementType { get { return elementType; } } + public bool ContainsNull { get { return containsNull; } } + + public ArrayType(DataType elementType, bool containsNull = true) + { + this.elementType = elementType; + this.containsNull = containsNull; + } + + internal ArrayType(JObject json) + { + FromJson(json); + } + + public override string SimpleString + { + get { return string.Format("array<{0}>", elementType.SimpleString); } + } + + internal override object JsonValue + { + get + { + return new JObject( + new JProperty("type", TypeName), + new JProperty("elementType", elementType.JsonValue), + new JProperty("containsNull", containsNull)); + } + } + + public override sealed DataType FromJson(JObject json) + { + elementType = ParseDataTypeFromJson(json["elementType"]); + containsNull = (bool)json["containsNull"]; + return this; + } + + private DataType elementType; + private bool containsNull; + } + + [Serializable] + public class MapType : ComplexType + { + internal override object JsonValue + { + get { throw new NotImplementedException(); } + } + + public override DataType FromJson(JObject json) + { + throw new NotImplementedException(); + } + } + + [Serializable] + public class StructField : ComplexType + { + public string Name { get { return name; } } + public DataType DataType { get { return dataType; } } + public bool IsNullable { get { return isNullable; } } + public JObject Metadata { get { return metadata; } } + + public StructField(string name, DataType dataType, bool isNullable = true, JObject metadata = null) + { + this.name = name; + this.dataType = dataType; + this.isNullable = isNullable; + this.metadata = metadata ??
new JObject(); + } + + internal StructField(JObject json) + { + FromJson(json); + } + + public override string SimpleString { get { return string.Format(@"{0}:{1}", name, dataType.SimpleString); } } + + internal override object JsonValue + { + get + { + return new JObject( + new JProperty("name", name), + new JProperty("type", dataType.JsonValue), + new JProperty("nullable", isNullable), + new JProperty("metadata", metadata)); + } + } + + public override sealed DataType FromJson(JObject json) + { + name = json["name"].ToString(); + dataType = ParseDataTypeFromJson(json["type"]); + isNullable = (bool)json["nullable"]; + metadata = (JObject)json["metadata"]; + return this; + } + + private string name; + private DataType dataType; + private bool isNullable; + [NonSerialized] + private JObject metadata; + } + + [Serializable] + public class StructType : ComplexType + { + public List Fields { get { return fields; } } + + internal IStructTypeProxy StructTypeProxy + { + get + { + return structTypeProxy ?? + new StructTypeIpcProxy( + new JvmObjectReference(SparkCLRIpcProxy.JvmBridge.CallStaticJavaMethod("org.apache.spark.sql.api.csharp.SQLUtils", "createSchema", + new object[] { Json }).ToString())); + } + } + + public StructType(IEnumerable fields) + { + this.fields = fields.ToList(); + } + + internal StructType(JObject json) + { + FromJson(json); + } + + internal StructType(IStructTypeProxy structTypeProxy) + { + this.structTypeProxy = structTypeProxy; + var jsonSchema = (structTypeProxy as StructTypeIpcProxy).ToJson(); + FromJson(jsonSchema); + } + + public override string SimpleString + { + get { return string.Format(@"struct<{0}>", string.Join(",", fields.Select(f => f.SimpleString))); } + } + + internal override object JsonValue + { + get + { + return new JObject( + new JProperty("type", TypeName), + new JProperty("fields", fields.Select(f => f.JsonValue).ToArray())); + } + } + + public override sealed DataType FromJson(JObject json) + { + var fieldsJObjects = json["fields"].Select(f => (JObject)f); + fields = fieldsJObjects.Select(fieldJObject => (new StructField(fieldJObject))).ToList(); + return this; + } + + [NonSerialized] + private readonly IStructTypeProxy structTypeProxy; + + private List fields; + } + +} diff --git a/csharp/AdapterTest/AdapterTest.csproj b/csharp/AdapterTest/AdapterTest.csproj index c4b7fe9..d14554b 100644 --- a/csharp/AdapterTest/AdapterTest.csproj +++ b/csharp/AdapterTest/AdapterTest.csproj @@ -40,6 +40,9 @@ ..\packages\Moq.4.2.1510.2205\lib\net40\Moq.dll True + + ..\packages\Newtonsoft.Json.7.0.1\lib\net45\Newtonsoft.Json.dll + ..\packages\NUnit.3.0.1\lib\net45\nunit.framework.dll True @@ -127,4 +130,4 @@ --> - + \ No newline at end of file diff --git a/csharp/AdapterTest/DataFrameTest.cs b/csharp/AdapterTest/DataFrameTest.cs index 3fb3124..53a4a5a 100644 --- a/csharp/AdapterTest/DataFrameTest.cs +++ b/csharp/AdapterTest/DataFrameTest.cs @@ -422,7 +422,7 @@ namespace AdapterTest "123", "Bill" }, - RowSchema.ParseRowSchemaFromJson(jsonSchema)) + DataType.ParseDataTypeFromJson(jsonSchema) as StructType) }; mockDataFrameProxy.Setup(m => m.JavaToCSharp()).Returns(new MockRddProxy(rows)); @@ -435,14 +435,6 @@ namespace AdapterTest Assert.IsNotNull(rdd); mockDataFrameProxy.Verify(m => m.JavaToCSharp(), Times.Once); - - mockDataFrameProxy.Reset(); - mockStructTypeProxy.Reset(); - - rdd = dataFrame.Rdd; - Assert.IsNotNull(rdd); - mockDataFrameProxy.Verify(m => m.JavaToCSharp(), Times.Never); - mockStructTypeProxy.Verify(m => m.ToJson(), Times.Never); } [Test] diff --git
a/csharp/AdapterTest/Mocks/MockRow.cs b/csharp/AdapterTest/Mocks/MockRow.cs index 23a83f6..bfa5b73 100644 --- a/csharp/AdapterTest/Mocks/MockRow.cs +++ b/csharp/AdapterTest/Mocks/MockRow.cs @@ -14,7 +14,7 @@ namespace AdapterTest.Mocks throw new NotImplementedException(); } - public override RowSchema GetSchema() + public override StructType GetSchema() { throw new NotImplementedException(); } diff --git a/csharp/AdapterTest/Mocks/MockSparkCLRProxy.cs b/csharp/AdapterTest/Mocks/MockSparkCLRProxy.cs index 72b13a9..07016b4 100644 --- a/csharp/AdapterTest/Mocks/MockSparkCLRProxy.cs +++ b/csharp/AdapterTest/Mocks/MockSparkCLRProxy.cs @@ -44,17 +44,6 @@ namespace AdapterTest.Mocks return new MockSparkContextProxy(conf); } - - public IStructFieldProxy CreateStructField(string name, string dataType, bool isNullable) - { - throw new NotImplementedException(); - } - - public IStructTypeProxy CreateStructType(List fields) - { - throw new NotImplementedException(); - } - public ISparkContextProxy SparkContextProxy { get { throw new NotImplementedException(); } diff --git a/csharp/AdapterTest/Mocks/MockSqlContextProxy.cs b/csharp/AdapterTest/Mocks/MockSqlContextProxy.cs index 4366410..cb74842 100644 --- a/csharp/AdapterTest/Mocks/MockSqlContextProxy.cs +++ b/csharp/AdapterTest/Mocks/MockSqlContextProxy.cs @@ -27,6 +27,11 @@ namespace AdapterTest.Mocks mockSparkContextProxy = scProxy; } + public IDataFrameProxy CreateDataFrame(IRDDProxy rddProxy, IStructTypeProxy structTypeProxy) + { + throw new NotImplementedException(); + } + public IDataFrameProxy ReadDataFrame(string path, StructType schema, System.Collections.Generic.Dictionary options) { throw new NotImplementedException(); } diff --git a/csharp/AdapterTest/packages.config b/csharp/AdapterTest/packages.config index 449d196..32cf3d2 100644 --- a/csharp/AdapterTest/packages.config +++ b/csharp/AdapterTest/packages.config @@ -1,6 +1,7 @@  + diff --git a/csharp/Samples/Microsoft.Spark.CSharp/DataFrameSamples.cs b/csharp/Samples/Microsoft.Spark.CSharp/DataFrameSamples.cs index e040395..f911086 100644 --- a/csharp/Samples/Microsoft.Spark.CSharp/DataFrameSamples.cs +++ b/csharp/Samples/Microsoft.Spark.CSharp/DataFrameSamples.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.IO; using System.Linq; +using System.Text.RegularExpressions; using Microsoft.Spark.CSharp.Core; using Microsoft.Spark.CSharp.Sql; using NUnit.Framework; @@ -27,6 +28,116 @@ namespace Microsoft.Spark.CSharp.Samples return sqlContext ?? (sqlContext = new SqlContext(SparkCLRSamples.SparkContext)); } + /// + /// Sample to create DataFrame. The RDD is generated from SparkContext Parallelize; the schema is created via object construction.
+ /// + [Sample] + internal static void DFCreateDataFrameSample() + { + var schemaPeople = new StructType(new List + { + new StructField("id", new StringType()), + new StructField("name", new StringType()), + new StructField("age", new IntegerType()), + new StructField("address", new StructType(new List + { + new StructField("city", new StringType()), + new StructField("state", new StringType()) + })), + new StructField("phone numbers", new ArrayType(new StringType())) + }); + + var rddPeople = SparkCLRSamples.SparkContext.Parallelize( + new List + { + new object[] { "123", "Bill", 43, new object[]{ "Columbus", "Ohio" }, new string[]{ "Tel1", "Tel2" } }, + new object[] { "456", "Steve", 34, new object[]{ "Seattle", "Washington" }, new string[]{ "Tel3", "Tel4" } } + }); + + var dataFramePeople = GetSqlContext().CreateDataFrame(rddPeople, schemaPeople); + Console.WriteLine("------ Schema of People Data Frame:\r\n"); + dataFramePeople.ShowSchema(); + Console.WriteLine(); + var collected = dataFramePeople.Collect().ToArray(); + foreach (var people in collected) + { + string id = people.Get("id"); + string name = people.Get("name"); + int age = people.Get("age"); + Row address = people.Get("address"); + string city = address.Get("city"); + string state = address.Get("state"); + object[] phoneNumbers = people.Get("phone numbers"); + Console.WriteLine("id:{0}, name:{1}, age:{2}, address:(city:{3},state:{4}), phoneNumbers:[{5},{6}]\r\n", id, name, age, city, state, phoneNumbers[0], phoneNumbers[1]); + } + + if (SparkCLRSamples.Configuration.IsValidationEnabled) + { + Assert.AreEqual(2, dataFramePeople.Rdd.Count()); + Assert.AreEqual(schemaPeople.Json, dataFramePeople.Schema.Json); + } + } + + /// + /// Sample to create DataFrame. The RDD is generated from SparkContext TextFile; the schema is created from JSON.
+ /// + [Sample] + internal static void DFCreateDataFrameSample2() + { + var rddRequestsLog = SparkCLRSamples.SparkContext.TextFile(SparkCLRSamples.Configuration.GetInputDataPath(RequestsLog), 1).Map(r => r.Split(',').Select(s => (object)s).ToArray()); + + const string schemaRequestsLogJson = @"{ ""fields"": [{ ""metadata"": {}, ""name"": ""guid"", ""nullable"": false, ""type"": ""string"" }, { ""metadata"": {}, ""name"": ""datacenter"", ""nullable"": false, ""type"": ""string"" }, { ""metadata"": {}, ""name"": ""abtestid"", ""nullable"": false, ""type"": ""string"" }, { ""metadata"": {}, ""name"": ""traffictype"", ""nullable"": false, ""type"": ""string"" }], ""type"": ""struct"" }"; + + // create schema by parsing JSON + StructType requestsLogSchema = DataType.ParseDataTypeFromJson(schemaRequestsLogJson) as StructType; + var dataFrameRequestsLog = GetSqlContext().CreateDataFrame(rddRequestsLog, requestsLogSchema); + + Console.WriteLine("------ Schema of RequestsLog Data Frame:"); + dataFrameRequestsLog.ShowSchema(); + Console.WriteLine(); + var collected = dataFrameRequestsLog.Collect().ToArray(); + foreach (var request in collected) + { + string guid = request.Get("guid"); + string datacenter = request.Get("datacenter"); + string abtestid = request.Get("abtestid"); + string traffictype = request.Get("traffictype"); + Console.WriteLine("guid:{0}, datacenter:{1}, abtestid:{2}, traffictype:{3}\r\n", guid, datacenter, abtestid, traffictype); + } + + if (SparkCLRSamples.Configuration.IsValidationEnabled) + { + Assert.AreEqual(10, collected.Length); + Assert.AreEqual(Regex.Replace(schemaRequestsLogJson, @"\s", string.Empty), Regex.Replace(dataFrameRequestsLog.Schema.Json, @"\s", string.Empty)); + } + } + /// /// Sample to show schema of DataFrame /// @@ -42,10 +153,10 @@ /// Sample to get schema of DataFrame in json format /// [Sample] - internal static void DFGetSchemaToJsonSample() + internal static void DFGetSchemaJsonSample() { var peopleDataFrame = GetSqlContext().JsonFile(SparkCLRSamples.Configuration.GetInputDataPath(PeopleJson)); - string json = peopleDataFrame.Schema.ToJson(); + string json = peopleDataFrame.Schema.Json; Console.WriteLine("schema in json format: {0}", json); } @@ -56,7 +167,7 @@ internal static void DFCollectSample() { var peopleDataFrame = GetSqlContext().JsonFile(SparkCLRSamples.Configuration.GetInputDataPath(PeopleJson)); - IEnumerable rows = peopleDataFrame.Collect(); + var rows = peopleDataFrame.Collect().ToArray(); Console.WriteLine("peopleDataFrame:"); int i = 0; @@ -122,15 +233,13 @@ [Sample] internal static void DFTextFileLoadDataFrameSample() { - var requestsSchema = StructType.CreateStructType( - new List - { - StructField.CreateStructField("guid", "string", false), - StructField.CreateStructField("datacenter", "string", false), - StructField.CreateStructField("abtestid", "string", false), - StructField.CreateStructField("traffictype", "string", false), - } - ); + var requestsSchema = new StructType(new List + { + new StructField("guid", new StringType(), false), + new StructField("datacenter", new StringType(), false), + new StructField("abtestid", new StringType(), false), + new StructField("traffictype", new StringType(), false), + }); var requestsDateFrame = GetSqlContext().TextFile(SparkCLRSamples.Configuration.GetInputDataPath(RequestsLog), requestsSchema);
requestsDateFrame.RegisterTempTable("requests"); @@ -154,18 +263,17 @@ namespace Microsoft.Spark.CSharp.Samples private static DataFrame GetMetricsDataFrame() { - var metricsSchema = StructType.CreateStructType( - new List + var metricsSchema = new StructType( + new[] { - StructField.CreateStructField("unknown", "string", false), - StructField.CreateStructField("date", "string", false), - StructField.CreateStructField("time", "string", false), - StructField.CreateStructField("guid", "string", false), - StructField.CreateStructField("lang", "string", false), - StructField.CreateStructField("country", "string", false), - StructField.CreateStructField("latency", "integer", false) - } - ); + new StructField("unknown", new StringType(), false), + new StructField("date", new StringType(), false), + new StructField("time", new StringType(), false), + new StructField("guid", new StringType(), false), + new StructField("lang", new StringType(), false), + new StructField("country", new StringType(), false), + new StructField("latency", new StringType(), false), + }); return GetSqlContext() @@ -236,8 +344,8 @@ { var name = peopleDataFrameSchemaField.Name; var dataType = peopleDataFrameSchemaField.DataType; - var stringVal = dataType.ToString(); - var simpleStringVal = dataType.SimpleString(); + var stringVal = dataType.TypeName; + var simpleStringVal = dataType.SimpleString; var isNullable = peopleDataFrameSchemaField.IsNullable; Console.WriteLine("Name={0}, DT.string={1}, DT.simplestring={2}, DT.isNullable={3}", name, stringVal, simpleStringVal, isNullable); } @@ -388,7 +496,7 @@ var singleValueReplaced = peopleDataFrame.Replace("Bill", "Bill.G"); singleValueReplaced.Show(); - + var multiValueReplaced = peopleDataFrame.ReplaceAll(new List { 14, 34 }, new List { 44, 54 }); multiValueReplaced.Show(); @@ -853,7 +961,7 @@ Console.WriteLine("peopleDataFrame:"); var count = 0; - RowSchema schema = null; + StructType schema = null; Row firstRow = null; foreach (var row in rows) { @@ -939,43 +1047,43 @@ /// Verify the schema of the people DataFrame.
/// /// RowSchema of people DataFrame - internal static void VerifySchemaOfPeopleDataFrame(RowSchema schema) + internal static void VerifySchemaOfPeopleDataFrame(StructType schema) { Assert.IsNotNull(schema); - Assert.AreEqual("struct", schema.type); - Assert.IsNotNull(schema.columns); - Assert.AreEqual(4, schema.columns.Count); + Assert.AreEqual("struct", schema.TypeName); + Assert.IsNotNull(schema.Fields); + Assert.AreEqual(4, schema.Fields.Count); // name - var nameColSchema = schema.columns.Find(c => c.name.Equals("name")); + var nameColSchema = schema.Fields.Find(c => c.Name.Equals("name")); Assert.IsNotNull(nameColSchema); - Assert.AreEqual("name", nameColSchema.name); - Assert.IsTrue(nameColSchema.nullable); - Assert.AreEqual("string", nameColSchema.type.ToString()); + Assert.AreEqual("name", nameColSchema.Name); + Assert.IsTrue(nameColSchema.IsNullable); + Assert.AreEqual("string", nameColSchema.DataType.TypeName); // id - var idColSchema = schema.columns.Find(c => c.name.Equals("id")); + var idColSchema = schema.Fields.Find(c => c.Name.Equals("id")); Assert.IsNotNull(idColSchema); - Assert.AreEqual("id", idColSchema.name); - Assert.IsTrue(idColSchema.nullable); - Assert.AreEqual("string", nameColSchema.type.ToString()); + Assert.AreEqual("id", idColSchema.Name); + Assert.IsTrue(idColSchema.IsNullable); + Assert.AreEqual("string", nameColSchema.DataType.TypeName); // age - var ageColSchema = schema.columns.Find(c => c.name.Equals("age")); + var ageColSchema = schema.Fields.Find(c => c.Name.Equals("age")); Assert.IsNotNull(ageColSchema); - Assert.AreEqual("age", ageColSchema.name); - Assert.IsTrue(ageColSchema.nullable); - Assert.AreEqual("long", ageColSchema.type.ToString()); + Assert.AreEqual("age", ageColSchema.Name); + Assert.IsTrue(ageColSchema.IsNullable); + Assert.AreEqual("long", ageColSchema.DataType.TypeName); // address - var addressColSchema = schema.columns.Find(c => c.name.Equals("address")); + var addressColSchema = schema.Fields.Find(c => c.Name.Equals("address")); Assert.IsNotNull(addressColSchema); - Assert.AreEqual("address", addressColSchema.name); - Assert.IsTrue(addressColSchema.nullable); - Assert.IsNotNull(addressColSchema.type); - Assert.AreEqual("struct", addressColSchema.type.type); - Assert.IsNotNull(addressColSchema.type.columns.Find(c => c.name.Equals("state"))); - Assert.IsNotNull(addressColSchema.type.columns.Find(c => c.name.Equals("city"))); + Assert.AreEqual("address", addressColSchema.Name); + Assert.IsTrue(addressColSchema.IsNullable); + Assert.IsNotNull(addressColSchema.DataType); + Assert.AreEqual("struct", addressColSchema.DataType.TypeName); + Assert.IsNotNull(((StructType)addressColSchema.DataType).Fields.Find(c => c.Name.Equals("state"))); + Assert.IsNotNull(((StructType)addressColSchema.DataType).Fields.Find(c => c.Name.Equals("city"))); } /// @@ -1128,25 +1236,25 @@ if (x == null && y == null) return true; if (x == null && y != null || x != null && y == null) return false; - foreach (var col in x.GetSchema().columns) + foreach (var col in x.GetSchema().Fields) { - if (!y.GetSchema().columns.Any(c => c.ToString() == col.ToString())) return false; + if (!y.GetSchema().Fields.Any(c => c.Name == col.Name)) return false; - if (col.type.columns.Any()) + if (col.DataType is StructType) { - if (!IsRowEqual(x.GetAs(col.name), y.GetAs(col.name), columnsComparer)) return false; + if (!IsRowEqual(x.GetAs(col.Name), y.GetAs(col.Name), columnsComparer)) return false; } - else if (x.Get(col.name) == null &&
y.Get(col.name) != null || x.Get(col.name) != null && y.Get(col.name) == null) return false; - else if (x.Get(col.name) != null && y.Get(col.name) != null) + else if (x.Get(col.Name) == null && y.Get(col.Name) != null || x.Get(col.Name) != null && y.Get(col.Name) == null) return false; + else if (x.Get(col.Name) != null && y.Get(col.Name) != null) { Func colComparer; - if (columnsComparer != null && columnsComparer.TryGetValue(col.name, out colComparer)) + if (columnsComparer != null && columnsComparer.TryGetValue(col.Name, out colComparer)) { - if (!colComparer(x.Get(col.name), y.Get(col.name))) return false; + if (!colComparer(x.Get(col.Name), y.Get(col.Name))) return false; } else { - if (x.Get(col.name).ToString() != y.Get(col.name).ToString()) return false; + if (x.Get(col.Name).ToString() != y.Get(col.Name).ToString()) return false; } } } diff --git a/csharp/Samples/Microsoft.Spark.CSharp/Samples.csproj b/csharp/Samples/Microsoft.Spark.CSharp/Samples.csproj index 1629841..8b64bcf 100644 --- a/csharp/Samples/Microsoft.Spark.CSharp/Samples.csproj +++ b/csharp/Samples/Microsoft.Spark.CSharp/Samples.csproj @@ -33,6 +33,10 @@ 4 + + ..\..\packages\Newtonsoft.Json.7.0.1\lib\net45\Newtonsoft.Json.dll + True + ..\..\packages\NUnit.3.0.1\lib\net45\nunit.framework.dll True diff --git a/csharp/Samples/Microsoft.Spark.CSharp/packages.config b/csharp/Samples/Microsoft.Spark.CSharp/packages.config index b183023..4abe7e9 100644 --- a/csharp/Samples/Microsoft.Spark.CSharp/packages.config +++ b/csharp/Samples/Microsoft.Spark.CSharp/packages.config @@ -1,4 +1,5 @@  + \ No newline at end of file diff --git a/scala/src/main/org/apache/spark/sql/api/csharp/SQLUtils.scala b/scala/src/main/org/apache/spark/sql/api/csharp/SQLUtils.scala index 6d9a426..74bfe1e 100644 --- a/scala/src/main/org/apache/spark/sql/api/csharp/SQLUtils.scala +++ b/scala/src/main/org/apache/spark/sql/api/csharp/SQLUtils.scala @@ -3,14 +3,16 @@ package org.apache.spark.sql.api.csharp -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.io.{ByteArrayOutputStream, DataOutputStream} import org.apache.spark.SparkContext import org.apache.spark.api.csharp.SerDe import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{DataType, FloatType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, FloatType, StructType} import org.apache.spark.sql._ +import java.util.{ArrayList => JArrayList} /** * Utility functions for DataFrame in SparkCLR @@ -31,10 +33,6 @@ object SQLUtils { arr.toSeq } - def createStructType(fields : Seq[StructField]): StructType = { - StructType(fields) - } - def getSQLDataType(dataType: String): DataType = { dataType match { case "byte" => org.apache.spark.sql.types.ByteType @@ -54,17 +52,6 @@ } } - def createStructField(name: String, dataType: String, nullable: Boolean): StructField = { - val dtObj = getSQLDataType(dataType) - StructField(name, dtObj, nullable) - } - - def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { - val num = schema.fields.size - val rowRDD = rdd.map(bytesToRow(_, schema)) - sqlContext.createDataFrame(rowRDD, schema) - } - def dfToRowRDD(df: DataFrame): RDD[Array[Byte]] = { df.map(r => rowToCSharpBytes(r)) } @@ -77,14 +64,6 @@ } } - private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = { - val bis = new
ByteArrayInputStream(bytes) - val dis = new DataInputStream(bis) - val num = SerDe.readInt(dis) - Row.fromSeq((0 until num).map { i => - doConversion(SerDe.readObject(dis), schema.fields(i).dataType) - }.toSeq) - } private[this] def rowToCSharpBytes(row: Row): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -176,8 +155,11 @@ object SQLUtils { dfReader.load(path) } - def loadTextFile(sqlContext: SQLContext, path: String, delimiter: String, schema: StructType) : DataFrame = { + def loadTextFile(sqlContext: SQLContext, path: String, delimiter: String, schemaJson: String) : DataFrame = { val stringRdd = sqlContext.sparkContext.textFile(path) + + val schema = createSchema(schemaJson) + val rowRdd = stringRdd.map{s => val columns = s.split(delimiter) columns.length match { @@ -217,4 +199,21 @@ object SQLUtils { sqlContext.createDataFrame(rowRdd, schema) } + + def createSchema(schemaJson: String) : StructType = { + DataType.fromJson(schemaJson).asInstanceOf[StructType] + } + + def byteArrayRDDToAnyArrayRDD(jrdd: JavaRDD[Array[Byte]]) : RDD[Array[_ >: AnyRef]] = { + // JavaRDD[Array[Byte]] -> JavaRDD[Any] + val jrddAny = SerDeUtil.pythonToJava(jrdd, true) + + // JavaRDD[Any] -> RDD[Array[_]] + jrddAny.rdd.map { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + } + } }
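
To see how the new pieces compose end to end, here is a minimal C# sketch (not part of the diff) of the JSON schema round trip introduced above. It uses only APIs from this change: StructType/StructField object construction from Types.cs, the Json property (which sorts keys via JsonSerDe.SortProperties so the Scala side's JSortedObject-based parser in SQLUtils.createSchema accepts the output), and DataType.ParseDataTypeFromJson, which RowConstructor.GetRow now uses. The Program wrapper and the two field names are hypothetical, and the expected output in the comments is inferred from the code in this diff rather than from a test run.

using System;
using System.Collections.Generic;
using Microsoft.Spark.CSharp.Sql;

class SchemaJsonRoundTripSketch
{
    static void Main()
    {
        // Build a schema with the new object-construction API
        // (replaces the removed StructType.CreateStructType/StructField.CreateStructField).
        var schema = new StructType(new List<StructField>
        {
            new StructField("guid", new StringType(), isNullable: false),
            new StructField("latency", new IntegerType())
        });

        // Json serializes with properties sorted by key, which is what the JVM side requires.
        // Expected form (per JsonValue/SortProperties in this diff):
        // {"fields":[{"metadata":{},"name":"guid","nullable":false,"type":"string"},
        //            {"metadata":{},"name":"latency","nullable":true,"type":"integer"}],"type":"struct"}
        string json = schema.Json;
        Console.WriteLine(json);

        // The same JSON parses back into a StructType, mirroring what SQLUtils.createSchema
        // does on the Scala side with DataType.fromJson.
        var parsed = (StructType)DataType.ParseDataTypeFromJson(json);
        Console.WriteLine(parsed.SimpleString); // struct<guid:string,latency:integer>
    }
}

This is also the path DFCreateDataFrameSample2 exercises: a schema defined as JSON text, parsed with ParseDataTypeFromJson, handed to SqlContext.CreateDataFrame, and validated by comparing Schema.Json against the original string.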