From b2fa3508d79dc16a9340175850e29ffe598de2a6 Mon Sep 17 00:00:00 2001
From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com>
Date: Fri, 15 Apr 2022 01:28:51 +0800
Subject: [PATCH] feat: add base classes for ML and refine code base (#1031)
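
This introduces abstract base classes mirroring Spark ML's hierarchy
(Params, PipelineStage, JavaTransformer, JavaEstimator, JavaModel,
JavaEvaluator), adds Pipeline and PipelineModel along with ParamMap,
ParamPair and JavaMLReader/JavaMLWriter support, and reworks the
existing feature transformers on top of the new hierarchy.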
---
.../IpcTests/ML/Feature/FeatureBaseTests.cs | 6 +-
.../IpcTests/ML/Feature/PipelineModelTests.cs | 79 ++++++++++++
.../IpcTests/ML/Feature/PipelineTests.cs | 111 ++++++++++++++++
.../Internal/Dotnet/ArrayExtensions.cs | 33 +++++
.../Internal/Dotnet/DictionaryExtensions.cs | 44 +++++++
.../Interop/Internal/Java/Util/HashMap.cs | 60 +++++++++
.../ML/Feature/{FeatureBase.cs => Base.cs} | 112 ++++++++++------
.../Microsoft.Spark/ML/Feature/Bucketizer.cs | 61 ++++++---
.../ML/Feature/CountVectorizer.cs | 64 +++++++---
.../ML/Feature/CountVectorizerModel.cs | 66 +++++++---
.../Microsoft.Spark/ML/Feature/Estimator.cs | 45 +++++++
.../Microsoft.Spark/ML/Feature/Evaluator.cs | 45 +++++++
.../ML/Feature/FeatureHasher.cs | 65 +++++++---
.../Microsoft.Spark/ML/Feature/HashingTF.cs | 45 +++++--
src/csharp/Microsoft.Spark/ML/Feature/IDF.cs | 40 ++++--
.../Microsoft.Spark/ML/Feature/IDFModel.cs | 42 ++++--
.../Microsoft.Spark/ML/Feature/Model.cs | 53 ++++++++
.../Microsoft.Spark/ML/Feature/NGram.cs | 30 ++++-
.../Microsoft.Spark/ML/Feature/Pipeline.cs | 120 ++++++++++++++++++
.../ML/Feature/PipelineModel.cs | 70 ++++++++++
.../ML/Feature/PipelineStage.cs | 50 ++++++++
.../ML/Feature/SQLTransformer.cs | 38 +++++-
.../ML/Feature/StopWordsRemover.cs | 30 ++++-
.../Microsoft.Spark/ML/Feature/Tokenizer.cs | 44 +++++--
.../Microsoft.Spark/ML/Feature/Transformer.cs | 37 ++++++
.../Microsoft.Spark/ML/Feature/Word2Vec.cs | 68 ++++++----
.../ML/Feature/Word2VecModel.cs | 42 ++++--
src/csharp/Microsoft.Spark/ML/Param/Param.cs | 6 +-
.../Microsoft.Spark/ML/Param/ParamMap.cs | 46 +++++++
.../Microsoft.Spark/ML/Param/ParamPair.cs | 29 +++++
src/csharp/Microsoft.Spark/ML/Util/Read.cs | 67 ++++++++++
src/csharp/Microsoft.Spark/ML/Util/Write.cs | 73 +++++++++++
src/scala/microsoft-spark-2-4/pom.xml | 6 +
.../apache/spark/api/dotnet/DotnetUtils.scala | 39 ++++++
.../spark/mllib/api/dotnet/MLUtils.scala | 26 ++++
src/scala/microsoft-spark-3-0/pom.xml | 6 +
.../apache/spark/api/dotnet/DotnetUtils.scala | 39 ++++++
.../spark/mllib/api/dotnet/MLUtils.scala | 26 ++++
src/scala/microsoft-spark-3-1/pom.xml | 6 +
.../apache/spark/api/dotnet/DotnetUtils.scala | 39 ++++++
.../spark/mllib/api/dotnet/MLUtils.scala | 26 ++++
src/scala/microsoft-spark-3-2/pom.xml | 6 +
.../apache/spark/api/dotnet/DotnetUtils.scala | 39 ++++++
.../spark/mllib/api/dotnet/MLUtils.scala | 26 ++++
44 files changed, 1793 insertions(+), 212 deletions(-)
create mode 100644 src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/PipelineModelTests.cs
create mode 100644 src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/PipelineTests.cs
create mode 100644 src/csharp/Microsoft.Spark/Interop/Internal/Dotnet/ArrayExtensions.cs
create mode 100644 src/csharp/Microsoft.Spark/Interop/Internal/Dotnet/DictionaryExtensions.cs
create mode 100644 src/csharp/Microsoft.Spark/Interop/Internal/Java/Util/HashMap.cs
rename src/csharp/Microsoft.Spark/ML/Feature/{FeatureBase.cs => Base.cs} (51%)
create mode 100644 src/csharp/Microsoft.Spark/ML/Feature/Estimator.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Feature/Evaluator.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Feature/Model.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Feature/Pipeline.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Feature/PipelineModel.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Feature/PipelineStage.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Feature/Transformer.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Param/ParamMap.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Param/ParamPair.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Util/Read.cs
create mode 100644 src/csharp/Microsoft.Spark/ML/Util/Write.cs
create mode 100644 src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
create mode 100644 src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
create mode 100644 src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
create mode 100644 src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
create mode 100644 src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
create mode 100644 src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
create mode 100644 src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
create mode 100644 src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureBaseTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureBaseTests.cs
index 01903e51..0f9be766 100644
--- a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureBaseTests.cs
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/FeatureBaseTests.cs
@@ -25,7 +25,7 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
/// The name of a parameter that can be set on this object
/// A parameter value that can be set on this object
public void TestFeatureBase(
- FeatureBase<T> testObject,
+ Params testObject,
string paramName,
object paramValue)
{
@@ -37,8 +37,8 @@ namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
Assert.Equal(param.Parent, testObject.Uid());
Assert.NotEmpty(testObject.ExplainParam(param));
- testObject.Set(param, paramValue);
- Assert.IsAssignableFrom<FeatureBase<T>>(testObject.Clear(param));
+ testObject.Set<Params>(param, paramValue);
+ Assert.IsAssignableFrom<Params>(testObject.Clear<Params>(param));
Assert.IsType<string>(testObject.Uid());
}
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/PipelineModelTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/PipelineModelTests.cs
new file mode 100644
index 00000000..7434d055
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/PipelineModelTests.cs
@@ -0,0 +1,79 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
+using Microsoft.Spark.Sql.Types;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+ [Collection("Spark E2E Tests")]
+ public class PipelineModelTests : FeatureBaseTests<PipelineModel>
+ {
+ private readonly SparkSession _spark;
+
+ public PipelineModelTests(SparkFixture fixture) : base(fixture)
+ {
+ _spark = fixture.Spark;
+ }
+
+ /// <summary>
+ /// Create a <see cref="PipelineModel"/> and test the
+ /// available methods.
+ /// </summary>
+ [Fact]
+ public void TestPipelineModelTransform()
+ {
+ var expectedSplits =
+ new double[] { double.MinValue, 0.0, 10.0, 50.0, double.MaxValue };
+
+ string expectedHandle = "skip";
+ string expectedUid = "uid";
+ string expectedInputCol = "input_col";
+ string expectedOutputCol = "output_col";
+
+ var bucketizer = new Bucketizer(expectedUid);
+ bucketizer.SetInputCol(expectedInputCol)
+ .SetOutputCol(expectedOutputCol)
+ .SetHandleInvalid(expectedHandle)
+ .SetSplits(expectedSplits);
+
+ var stages = new JavaTransformer[] {
+ bucketizer
+ };
+
+ PipelineModel pipelineModel = new PipelineModel("randomUID", stages);
+
+ DataFrame input = _spark.Sql("SELECT ID as input_col from range(100)");
+
+ DataFrame output = pipelineModel.Transform(input);
+ Assert.Contains(output.Schema().Fields, (f => f.Name == expectedOutputCol));
+
+ Assert.Equal(expectedInputCol, bucketizer.GetInputCol());
+ Assert.Equal(expectedOutputCol, bucketizer.GetOutputCol());
+ Assert.Equal(expectedSplits, bucketizer.GetSplits());
+
+ Assert.IsType<StructType>(pipelineModel.TransformSchema(input.Schema()));
+ Assert.IsType<DataFrame>(output);
+
+ using (var tempDirectory = new TemporaryDirectory())
+ {
+ string savePath = Path.Join(tempDirectory.Path, "pipelineModel");
+ pipelineModel.Save(savePath);
+
+ PipelineModel loadedPipelineModel = PipelineModel.Load(savePath);
+ Assert.Equal(pipelineModel.Uid(), loadedPipelineModel.Uid());
+
+ string writePath = Path.Join(tempDirectory.Path, "pipelineModelWithWrite");
+ pipelineModel.Write().Save(writePath);
+
+ PipelineModel loadedPipelineModelWithRead = pipelineModel.Read().Load(writePath);
+ Assert.Equal(pipelineModel.Uid(), loadedPipelineModelWithRead.Uid());
+ }
+ }
+ }
+}
diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/PipelineTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/PipelineTests.cs
new file mode 100644
index 00000000..3a07335d
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/PipelineTests.cs
@@ -0,0 +1,111 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.UnitTest.TestUtils;
+using Microsoft.Spark.Sql.Types;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+ [Collection("Spark E2E Tests")]
+ public class PipelineTests : FeatureBaseTests<Pipeline>
+ {
+ private readonly SparkSession _spark;
+
+ public PipelineTests(SparkFixture fixture) : base(fixture)
+ {
+ _spark = fixture.Spark;
+ }
+
+ /// <summary>
+ /// Create a <see cref="Pipeline"/> and test the
+ /// available methods. Test the FeatureBase methods
+ /// using <see cref="TestFeatureBase"/>.
+ /// </summary>
+ [Fact]
+ public void TestPipeline()
+ {
+ var stages = new JavaPipelineStage[] {
+ new Bucketizer(),
+ new CountVectorizer()
+ };
+
+ Pipeline pipeline = new Pipeline()
+ .SetStages(stages);
+ JavaPipelineStage[] returnStages = pipeline.GetStages();
+
+ Assert.Equal(stages[0].Uid(), returnStages[0].Uid());
+ Assert.Equal(stages[0].ToString(), returnStages[0].ToString());
+ Assert.Equal(stages[1].Uid(), returnStages[1].Uid());
+ Assert.Equal(stages[1].ToString(), returnStages[1].ToString());
+
+ using (var tempDirectory = new TemporaryDirectory())
+ {
+ string savePath = Path.Join(tempDirectory.Path, "pipeline");
+ pipeline.Save(savePath);
+
+ Pipeline loadedPipeline = Pipeline.Load(savePath);
+ Assert.Equal(pipeline.Uid(), loadedPipeline.Uid());
+ }
+
+ TestFeatureBase(pipeline, "stages", stages);
+ }
+
+ /// <summary>
+ /// Create a <see cref="Pipeline"/> and test the
+ /// fit and read/write methods.
+ /// </summary>
+ [Fact]
+ public void TestPipelineFit()
+ {
+ DataFrame input = _spark.Sql("SELECT array('hello', 'I', 'AM', 'a', 'string', 'TO', " +
+ "'TOKENIZE') as input from range(100)");
+
+ const string inputColumn = "input";
+ const string outputColumn = "output";
+ const double minDf = 1;
+ const double minTf = 10;
+ const int vocabSize = 10000;
+
+ CountVectorizer countVectorizer = new CountVectorizer()
+ .SetInputCol(inputColumn)
+ .SetOutputCol(outputColumn)
+ .SetMinDF(minDf)
+ .SetMinTF(minTf)
+ .SetVocabSize(vocabSize);
+
+ var stages = new JavaPipelineStage[] {
+ countVectorizer
+ };
+
+ Pipeline pipeline = new Pipeline().SetStages(stages);
+ PipelineModel pipelineModel = pipeline.Fit(input);
+
+ DataFrame output = pipelineModel.Transform(input);
+
+ Assert.IsType<StructType>(pipelineModel.TransformSchema(input.Schema()));
+ Assert.IsType<DataFrame>(output);
+
+ using (var tempDirectory = new TemporaryDirectory())
+ {
+ string savePath = Path.Join(tempDirectory.Path, "pipeline");
+ pipeline.Save(savePath);
+
+ Pipeline loadedPipeline = Pipeline.Load(savePath);
+ Assert.Equal(pipeline.Uid(), loadedPipeline.Uid());
+
+ string writePath = Path.Join(tempDirectory.Path, "pipelineWithWrite");
+ pipeline.Write().Save(writePath);
+
+ Pipeline loadedPipelineWithRead = pipeline.Read().Load(writePath);
+ Assert.Equal(pipeline.Uid(), loadedPipelineWithRead.Uid());
+ }
+
+ TestFeatureBase(pipeline, "stages", stages);
+ }
+ }
+}
diff --git a/src/csharp/Microsoft.Spark/Interop/Internal/Dotnet/ArrayExtensions.cs b/src/csharp/Microsoft.Spark/Interop/Internal/Dotnet/ArrayExtensions.cs
new file mode 100644
index 00000000..fb59a5ec
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/Interop/Internal/Dotnet/ArrayExtensions.cs
@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Internal.Java.Util;
+
+namespace System
+{
+ /// <summary>
+ /// ArrayExtensions host custom extension methods for the
+ /// dotnet base class array T[].
+ /// </summary>
+ public static class ArrayExtensions
+ {
+ /// <summary>
+ /// A custom extension method that helps transform from dotnet
+ /// array of type T to java.util.ArrayList.
+ /// </summary>
+ /// <param name="array">an array instance</param>
+ /// <typeparam name="T">elements type of param array</typeparam>
+ /// <returns><see cref="ArrayList"/></returns>
+ internal static ArrayList ToJavaArrayList<T>(this T[] array)
+ {
+ var arrayList = new ArrayList(SparkEnvironment.JvmBridge);
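+ // Copy each element over the JVM bridge into the java.util.ArrayList.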
+ foreach (T item in array)
+ {
+ arrayList.Add(item);
+ }
+ return arrayList;
+ }
+ }
+}
diff --git a/src/csharp/Microsoft.Spark/Interop/Internal/Dotnet/DictionaryExtensions.cs b/src/csharp/Microsoft.Spark/Interop/Internal/Dotnet/DictionaryExtensions.cs
new file mode 100644
index 00000000..7300eb51
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/Interop/Internal/Dotnet/DictionaryExtensions.cs
@@ -0,0 +1,44 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Internal.Java.Util;
+
+namespace System.Collections.Generic
+{
+ public static class Dictionary
+ {
+ /// <summary>
+ /// A custom extension method that helps transform from dotnet
+ /// Dictionary<string, string> to java.util.HashMap.
+ /// </summary>
+ /// <param name="dictionary">a Dictionary instance</param>
+ /// <returns><see cref="HashMap"/></returns>
+ internal static HashMap ToJavaHashMap(this Dictionary<string, string> dictionary)
+ {
+ var hashMap = new HashMap(SparkEnvironment.JvmBridge);
+ foreach (KeyValuePair<string, string> item in dictionary)
+ {
+ hashMap.Put(item.Key, item.Value);
+ }
+ return hashMap;
+ }
+
+ /// <summary>
+ /// A custom extension method that helps transform from dotnet
+ /// Dictionary<string, object> to java.util.HashMap.
+ /// </summary>
+ /// <param name="dictionary">a Dictionary instance</param>
+ /// <returns><see cref="HashMap"/></returns>
+ internal static HashMap ToJavaHashMap(this Dictionary<string, object> dictionary)
+ {
+ var hashMap = new HashMap(SparkEnvironment.JvmBridge);
+ foreach (KeyValuePair<string, object> item in dictionary)
+ {
+ hashMap.Put(item.Key, item.Value);
+ }
+ return hashMap;
+ }
+ }
+}
diff --git a/src/csharp/Microsoft.Spark/Interop/Internal/Java/Util/HashMap.cs b/src/csharp/Microsoft.Spark/Interop/Internal/Java/Util/HashMap.cs
new file mode 100644
index 00000000..fb992d2b
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/Interop/Internal/Java/Util/HashMap.cs
@@ -0,0 +1,60 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.Interop.Internal.Java.Util
+{
+ /// <summary>
+ /// HashMap class represents a java.util.HashMap object.
+ /// </summary>
+ internal sealed class HashMap : IJvmObjectReferenceProvider
+ {
+ /// <summary>
+ /// Create a java.util.HashMap JVM object
+ /// </summary>
+ /// <param name="jvm">JVM bridge to use</param>
+ internal HashMap(IJvmBridge jvm) =>
+ Reference = jvm.CallConstructor("java.util.HashMap");
+
+ public JvmObjectReference Reference { get; private set; }
+
+ /// <summary>
+ /// Associates the specified value with the specified key in this map.
+ /// If the map previously contained a mapping for the key, the old value is replaced.
+ /// </summary>
+ /// <param name="key">key with which the specified value is to be associated</param>
+ /// <param name="value">value to be associated with the specified key</param>
+ internal void Put(object key, object value) =>
+ Reference.Invoke("put", key, value);
+
+ /// <summary>
+ /// Returns the value to which the specified key is mapped,
+ /// or null if this map contains no mapping for the key.
+ /// </summary>
+ /// <param name="key">value whose presence in this map is to be tested</param>
+ /// <returns>value associated with the specified key</returns>
+ internal object Get(object key) =>
+ Reference.Invoke("get", key);
+
+ /// <summary>
+ /// Returns true if this map maps one or more keys to the specified value.
+ /// </summary>
+ /// <param name="value">The HashMap value</param>
+ /// <returns>true if this map maps one or more keys to the specified value</returns>
+ internal bool ContainsValue(object value) =>
+ (bool)Reference.Invoke("containsValue", value);
+
+ /// <summary>
+ /// Returns an array of the keys contained in this map.
+ /// </summary>
+ /// <returns>An array of object hosting the keys contained in the map</returns>
+ internal object[] Keys()
+ {
+ var jvmObject = (JvmObjectReference)Reference.Invoke("keySet");
+ var result = (object[])jvmObject.Invoke("toArray");
+ return result;
+ }
+ }
+}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/FeatureBase.cs b/src/csharp/Microsoft.Spark/ML/Feature/Base.cs
similarity index 51%
rename from src/csharp/Microsoft.Spark/ML/Feature/FeatureBase.cs
rename to src/csharp/Microsoft.Spark/ML/Feature/Base.cs
index 3a448771..3ad4d935 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/FeatureBase.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Base.cs
@@ -1,4 +1,7 @@
-using System;
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
using System.Linq;
using System.Reflection;
using Microsoft.Spark.Interop;
@@ -7,27 +10,23 @@ using Microsoft.Spark.Interop.Ipc;
namespace Microsoft.Spark.ML.Feature
{
/// <summary>
- /// FeatureBase is to share code amongst all of the ML.Feature objects, there are a few
- /// interfaces that the Scala code implements across all of the objects. This should help to
- /// write the extra objects faster.
+ /// Params is used for components that take parameters. This also provides
+ /// an internal param map to store parameter values attached to the instance.
+ /// An abstract class that corresponds to Scala's Params trait.
/// </summary>
- /// <typeparam name="T">
- /// The class that implements FeatureBase, this is needed so we can create new objects where
- /// spark returns new objects rather than update existing objects.
- /// </typeparam>
- public class FeatureBase<T> : Identifiable, IJvmObjectReferenceProvider
- {
- internal FeatureBase(string className)
+ public abstract class Params : Identifiable, IJvmObjectReferenceProvider
+ {
+ internal Params(string className)
: this(SparkEnvironment.JvmBridge.CallConstructor(className))
{
}
-
- internal FeatureBase(string className, string uid)
+
+ internal Params(string className, string uid)
: this(SparkEnvironment.JvmBridge.CallConstructor(className, uid))
{
}
-
- internal FeatureBase(JvmObjectReference jvmObject)
+
+ internal Params(JvmObjectReference jvmObject)
{
Reference = jvmObject;
}
@@ -39,7 +38,7 @@ namespace Microsoft.Spark.ML.Feature
/// </summary>
/// <returns>JVM toString() value</returns>
public override string ToString() => (string)Reference.Invoke("toString");
-
+
/// <summary>
/// The UID that was used to create the object. If no UID is passed in when creating the
/// object then a random UID is created when the object is created.
/// </summary>
@@ -47,30 +46,12 @@ namespace Microsoft.Spark.ML.Feature
/// <returns>string UID identifying the object</returns>
public string Uid() => (string)Reference.Invoke("uid");
- /// <summary>
- /// Saves the object so that it can be loaded later using Load. Note that these objects
- /// can be shared with Scala by Loading or Saving in Scala.
- /// </summary>
- /// <param name="path">The path to save the object to</param>
- /// <returns>New object</returns>
- public T Save(string path) =>
- WrapAsType((JvmObjectReference)Reference.Invoke("save", path));
-
- /// <summary>
- /// Clears any value that was previously set for this <see cref="Param.Param"/>. The value is
- /// reset to the default value.
- /// </summary>
- /// <param name="param">The <see cref="Param.Param"/> to set back to its original value</param>
- /// <returns>Object reference that was used to clear the <see cref="Param.Param"/></returns>
- public T Clear(Param.Param param) =>
- WrapAsType((JvmObjectReference)Reference.Invoke("clear", param));
-
/// <summary>
/// Returns a description of how a specific <see cref="Param.Param"/> works and is currently set.
/// </summary>
/// <param name="param">The <see cref="Param.Param"/> to explain</param>
/// <returns>Description of the <see cref="Param.Param"/></returns>
- public string ExplainParam(Param.Param param) =>
+ public string ExplainParam(Param.Param param) =>
(string)Reference.Invoke("explainParam", param);
/// <summary>
@@ -80,13 +61,30 @@ namespace Microsoft.Spark.ML.Feature
/// <returns>Description of all the applicable <see cref="Param.Param"/>'s</returns>
public string ExplainParams() => (string)Reference.Invoke("explainParams");
+ /// <summary>Checks whether a param is explicitly set.</summary>
+ /// <param name="param">The <see cref="Param.Param"/> to be checked.</param>
+ /// <returns>bool</returns>
+ public bool IsSet(Param.Param param) => (bool)Reference.Invoke("isSet", param);
+
+ /// <summary>Checks whether a param is explicitly set or has a default value.</summary>
+ /// <param name="param">The <see cref="Param.Param"/> to be checked.</param>
+ /// <returns>bool</returns>
+ public bool IsDefined(Param.Param param) => (bool)Reference.Invoke("isDefined", param);
+
+ /// <summary>
+ /// Tests whether this instance contains a param with a given name.
+ /// </summary>
+ /// <param name="paramName">The <see cref="Param.Param"/> name to be tested.</param>
+ /// <returns>bool</returns>
+ public bool HasParam(string paramName) => (bool)Reference.Invoke("hasParam", paramName);
+
/// <summary>
/// Retrieves a <see cref="Param.Param"/> so that it can be used to set the value of the
/// <see cref="Param.Param"/> on the object.
/// </summary>
/// <param name="paramName">The name of the <see cref="Param.Param"/> to get.</param>
/// <returns><see cref="Param.Param"/> that can be used to set the actual value</returns>
- public Param.Param GetParam(string paramName) =>
+ public Param.Param GetParam(string paramName) =>
new Param.Param((JvmObjectReference)Reference.Invoke("getParam", paramName));
/// <summary>
@@ -95,10 +93,19 @@ namespace Microsoft.Spark.ML.Feature
/// <param name="param"><see cref="Param.Param"/> to set the value of</param>
/// <param name="value">The value to use</param>
/// <returns>The object that contains the newly set <see cref="Param.Param"/></returns>
- public T Set(Param.Param param, object value) =>
- WrapAsType((JvmObjectReference)Reference.Invoke("set", param, value));
+ public T Set<T>(Param.Param param, object value) =>
+ WrapAsType<T>((JvmObjectReference)Reference.Invoke("set", param, value));
- private static T WrapAsType(JvmObjectReference reference)
+ /// <summary>
+ /// Clears any value that was previously set for this <see cref="Param.Param"/>. The value is
+ /// reset to the default value.
+ /// </summary>
+ /// <param name="param">The <see cref="Param.Param"/> to set back to its original value</param>
+ /// <returns>Object reference that was used to clear the <see cref="Param.Param"/></returns>
+ public T Clear<T>(Param.Param param) =>
+ WrapAsType<T>((JvmObjectReference)Reference.Invoke("clear", param));
+
+ protected static T WrapAsType<T>(JvmObjectReference reference)
{
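// Reflectively locate the wrapper's internal constructor that takes a single
// JvmObjectReference, then invoke it to build the strongly typed .NET proxy.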
ConstructorInfo constructor = typeof(T)
.GetConstructors(BindingFlags.NonPublic | BindingFlags.Instance)
@@ -109,7 +116,32 @@ namespace Microsoft.Spark.ML.Feature
(parameters[0].ParameterType == typeof(JvmObjectReference));
});
- return (T)constructor.Invoke(new object[] {reference});
+ return (T)constructor.Invoke(new object[] { reference });
+ }
+ }
+
+ /// <summary>
+ /// DotnetUtils is used to hold basic general helper functions that
+ /// are used within ML scope.
+ /// </summary>
+ internal class DotnetUtils
+ {
+ /// <summary>
+ /// Helper function for getting the exact class name from jvm object.
+ /// </summary>
+ /// <param name="jvmObject">The reference to object created in JVM.</param>
+ /// <returns>A string Tuple2 of constructor class name and method name</returns>
+ internal static (string, string) GetUnderlyingType(JvmObjectReference jvmObject)
+ {
+ var jvmClass = (JvmObjectReference)jvmObject.Invoke("getClass");
+ var returnClass = (string)jvmClass.Invoke("getTypeName");
+ string[] dotnetClass = returnClass.Replace("com.microsoft.azure.synapse.ml", "Synapse.ML")
+ .Replace("org.apache.spark.ml", "Microsoft.Spark.ML")
+ .Split(".".ToCharArray());
+ string[] renameClass = dotnetClass.Select(x => char.ToUpper(x[0]) + x.Substring(1)).ToArray();
+ string constructorClass = string.Join(".", renameClass);
+ string methodName = "WrapAs" + dotnetClass[dotnetClass.Length - 1];
+ return (constructorClass, methodName);
}
}
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Bucketizer.cs b/src/csharp/Microsoft.Spark/ML/Feature/Bucketizer.cs
index b3989dbf..ddd7c7c0 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/Bucketizer.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Bucketizer.cs
@@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Spark.Interop;
@@ -20,11 +19,14 @@ namespace Microsoft.Spark.ML.Feature
/// will be thrown. The splits parameter is only used for single column usage, and splitsArray
/// is for multiple columns.
///
- public class Bucketizer : FeatureBase<Bucketizer>
+ public class Bucketizer :
+ JavaModel<Bucketizer>,
+ IJavaMLWritable,
+ IJavaMLReadable<Bucketizer>
{
- private static readonly string s_bucketizerClassName =
+ private static readonly string s_bucketizerClassName =
"org.apache.spark.ml.feature.Bucketizer";
-
+
///
/// Create a without any parameters
///
@@ -40,11 +42,11 @@ namespace Microsoft.Spark.ML.Feature
public Bucketizer(string uid) : base(s_bucketizerClassName, uid)
{
}
-
+
internal Bucketizer(JvmObjectReference jvmObject) : base(jvmObject)
{
}
-
+
///
/// Gets the splits that were set using SetSplits
///
@@ -62,7 +64,7 @@ namespace Microsoft.Spark.ML.Feature
/// increasing. Values outside the splits specified will be treated as errors.
///
/// New object
- public Bucketizer SetSplits(double[] value) =>
+ public Bucketizer SetSplits(double[] value) =>
WrapAsBucketizer(Reference.Invoke("setSplits", value));
///
@@ -82,7 +84,7 @@ namespace Microsoft.Spark.ML.Feature
/// includes y. The splits should be of length >= 3 and strictly increasing.
/// Values outside the splits specified will be treated as errors.
/// New object
- public Bucketizer SetSplitsArray(double[][] value) =>
+ public Bucketizer SetSplitsArray(double[][] value) =>
WrapAsBucketizer(Reference.Invoke("setSplitsArray", (object)value));
///
@@ -98,15 +100,15 @@ namespace Microsoft.Spark.ML.Feature
///
/// The name of the column to as the source of the buckets
/// New object
- public Bucketizer SetInputCol(string value) =>
+ public Bucketizer SetInputCol(string value) =>
WrapAsBucketizer(Reference.Invoke("setInputCol", value));
///
/// Gets the columns that should read from and convert into
/// buckets. This is set by SetInputCol
///
- /// <returns>IEnumerable<string>, list of input columns</returns>
- public IEnumerable<string> GetInputCols() =>
+ /// <returns>IEnumerable<string>, list of input columns</returns>
+ public IEnumerable<string> GetInputCols() =>
((string[])(Reference.Invoke("getInputCols"))).ToList();
///
@@ -118,7 +120,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// List of input columns to use as sources for buckets
/// New object
- public Bucketizer SetInputCols(IEnumerable<string> value) =>
+ public Bucketizer SetInputCols(IEnumerable<string> value) =>
WrapAsBucketizer(Reference.Invoke("setInputCols", value));
///
@@ -134,7 +136,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The name of the new column which contains the bucket ID
/// New object
- public Bucketizer SetOutputCol(string value) =>
+ public Bucketizer SetOutputCol(string value) =>
WrapAsBucketizer(Reference.Invoke("setOutputCol", value));
///
@@ -142,7 +144,7 @@ namespace Microsoft.Spark.ML.Feature
/// This is set by SetOutputCols
///
/// <returns>IEnumerable<string>, list of output columns</returns>
- public IEnumerable<string> GetOutputCols() =>
+ public IEnumerable<string> GetOutputCols() =>
((string[])Reference.Invoke("getOutputCols")).ToList();
///
@@ -150,7 +152,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// List of column names which will contain the bucket ID
/// New object
- public Bucketizer SetOutputCols(List<string> value) =>
+ public Bucketizer SetOutputCols(List<string> value) =>
WrapAsBucketizer(Reference.Invoke("setOutputCols", value));
///
@@ -161,7 +163,7 @@ namespace Microsoft.Spark.ML.Feature
public static Bucketizer Load(string path) =>
WrapAsBucketizer(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
- s_bucketizerClassName,"load", path));
+ s_bucketizerClassName, "load", path));
///
/// Executes the and transforms the DataFrame to include the new
@@ -171,7 +173,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// containing the original data and the new bucketed columns
///
- public DataFrame Transform(DataFrame source) =>
+ public override DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", source));
///
@@ -188,10 +190,31 @@ namespace Microsoft.Spark.ML.Feature
///
/// "skip", "error" or "keep"
/// New object
- public Bucketizer SetHandleInvalid(string value) =>
+ public Bucketizer SetHandleInvalid(string value) =>
WrapAsBucketizer(Reference.Invoke("setHandleInvalid", value.ToString()));
- private static Bucketizer WrapAsBucketizer(object obj) =>
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader<Bucketizer>"/> instance for this ML instance.</returns>
+ public JavaMLReader<Bucketizer> Read() =>
+ new JavaMLReader<Bucketizer>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static Bucketizer WrapAsBucketizer(object obj) =>
new Bucketizer((JvmObjectReference)obj);
}
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/CountVectorizer.cs b/src/csharp/Microsoft.Spark/ML/Feature/CountVectorizer.cs
index f5aa238d..f99c2ced 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/CountVectorizer.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/CountVectorizer.cs
@@ -8,11 +8,14 @@ using Microsoft.Spark.Sql;
namespace Microsoft.Spark.ML.Feature
{
- public class CountVectorizer : FeatureBase<CountVectorizer>
+ public class CountVectorizer :
+ JavaEstimator<CountVectorizerModel>,
+ IJavaMLWritable,
+ IJavaMLReadable<CountVectorizer>
{
- private static readonly string s_countVectorizerClassName =
+ private static readonly string s_countVectorizerClassName =
"org.apache.spark.ml.feature.CountVectorizer";
-
+
///
/// Creates a without any parameters.
///
@@ -28,7 +31,7 @@ namespace Microsoft.Spark.ML.Feature
public CountVectorizer(string uid) : base(s_countVectorizerClassName, uid)
{
}
-
+
internal CountVectorizer(JvmObjectReference jvmObject) : base(jvmObject)
{
}
@@ -36,7 +39,7 @@ namespace Microsoft.Spark.ML.Feature
/// Fits a model to the input data.
/// The to fit the model to.
///
- public CountVectorizerModel Fit(DataFrame dataFrame) =>
+ public override CountVectorizerModel Fit(DataFrame dataFrame) =>
new CountVectorizerModel((JvmObjectReference)Reference.Invoke("fit", dataFrame));
///
@@ -49,8 +52,8 @@ namespace Microsoft.Spark.ML.Feature
public static CountVectorizer Load(string path) =>
WrapAsCountVectorizer((JvmObjectReference)
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
- s_countVectorizerClassName,"load", path));
-
+ s_countVectorizerClassName, "load", path));
+
///
/// Gets the binary toggle to control the output vector values. If True, all nonzero counts
/// (after minTF filter applied) are set to 1. This is useful for discrete probabilistic
@@ -58,7 +61,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// boolean
public bool GetBinary() => (bool)Reference.Invoke("getBinary");
-
+
///
/// Sets the binary toggle to control the output vector values. If True, all nonzero counts
/// (after minTF filter applied) are set to 1. This is useful for discrete probabilistic
@@ -75,7 +78,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The input column of type string
public string GetInputCol() => (string)Reference.Invoke("getInputCol");
-
+
///
/// Sets the column that the should read from.
///
@@ -83,14 +86,14 @@ namespace Microsoft.Spark.ML.Feature
/// with the input column set
public CountVectorizer SetInputCol(string value) =>
WrapAsCountVectorizer((JvmObjectReference)Reference.Invoke("setInputCol", value));
-
+
///
/// Gets the name of the new column the creates in the
/// DataFrame.
///
/// The name of the output column.
public string GetOutputCol() => (string)Reference.Invoke("getOutputCol");
-
+
///
/// Sets the name of the new column the creates in the
/// DataFrame.
@@ -99,7 +102,7 @@ namespace Microsoft.Spark.ML.Feature
/// New with the output column set
public CountVectorizer SetOutputCol(string value) =>
WrapAsCountVectorizer((JvmObjectReference)Reference.Invoke("setOutputCol", value));
-
+
///
/// Gets the maximum number of different documents a term could appear in to be included in
/// the vocabulary. A term that appears more than the threshold will be ignored. If this is
@@ -123,7 +126,7 @@ namespace Microsoft.Spark.ML.Feature
[Since(Versions.V2_4_0)]
public CountVectorizer SetMaxDF(double value) =>
WrapAsCountVectorizer((JvmObjectReference)Reference.Invoke("setMaxDF", value));
-
+
///
/// Gets the minimum number of different documents a term must appear in to be included in
/// the vocabulary. If this is an integer greater than or equal to 1, this specifies the
@@ -132,7 +135,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The minimum document term frequency
public double GetMinDF() => (double)Reference.Invoke("getMinDF");
-
+
///
/// Sets the minimum number of different documents a term must appear in to be included in
/// the vocabulary. If this is an integer greater than or equal to 1, this specifies the
@@ -143,7 +146,7 @@ namespace Microsoft.Spark.ML.Feature
/// New with the min df value set
public CountVectorizer SetMinDF(double value) =>
WrapAsCountVectorizer((JvmObjectReference)Reference.Invoke("setMinDF", value));
-
+
///
/// Gets the filter to ignore rare words in a document. For each document, terms with
/// frequency/count less than the given threshold are ignored. If this is an integer
@@ -171,7 +174,7 @@ namespace Microsoft.Spark.ML.Feature
/// New with the min term frequency set
public CountVectorizer SetMinTF(double value) =>
WrapAsCountVectorizer((JvmObjectReference)Reference.Invoke("setMinTF", value));
-
+
///
/// Gets the max size of the vocabulary. will build a
/// vocabulary that only considers the top vocabSize terms ordered by term frequency across
@@ -179,7 +182,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The max size of the vocabulary of type int.
public int GetVocabSize() => (int)Reference.Invoke("getVocabSize");
-
+
///
/// Sets the max size of the vocabulary. will build a
/// vocabulary that only considers the top vocabSize terms ordered by term frequency across
@@ -187,10 +190,31 @@ namespace Microsoft.Spark.ML.Feature
///
/// The max vocabulary size
/// with the max vocab value set
- public CountVectorizer SetVocabSize(int value) =>
+ public CountVectorizer SetVocabSize(int value) =>
WrapAsCountVectorizer(Reference.Invoke("setVocabSize", value));
-
- private static CountVectorizer WrapAsCountVectorizer(object obj) =>
+
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader<CountVectorizer>"/> instance for this ML instance.</returns>
+ public JavaMLReader<CountVectorizer> Read() =>
+ new JavaMLReader<CountVectorizer>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static CountVectorizer WrapAsCountVectorizer(object obj) =>
new CountVectorizer((JvmObjectReference)obj);
}
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/CountVectorizerModel.cs b/src/csharp/Microsoft.Spark/ML/Feature/CountVectorizerModel.cs
index 81c22e2d..303ba8ee 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/CountVectorizerModel.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/CountVectorizerModel.cs
@@ -10,16 +10,19 @@ using Microsoft.Spark.Sql.Types;
namespace Microsoft.Spark.ML.Feature
{
- public class CountVectorizerModel : FeatureBase<CountVectorizerModel>
+ public class CountVectorizerModel :
+ JavaModel<CountVectorizerModel>,
+ IJavaMLWritable,
+ IJavaMLReadable<CountVectorizerModel>
{
- private static readonly string s_countVectorizerModelClassName =
+ private static readonly string s_countVectorizerModelClassName =
"org.apache.spark.ml.feature.CountVectorizerModel";
-
+
///
/// Creates a without any parameters
///
/// The vocabulary to use
- public CountVectorizerModel(List<string> vocabulary)
+ public CountVectorizerModel(List<string> vocabulary)
: this(SparkEnvironment.JvmBridge.CallConstructor(
s_countVectorizerModelClassName, vocabulary))
{
@@ -31,16 +34,16 @@ namespace Microsoft.Spark.ML.Feature
///
/// An immutable unique ID for the object and its derivatives.
/// The vocabulary to use
- public CountVectorizerModel(string uid, List<string> vocabulary)
+ public CountVectorizerModel(string uid, List<string> vocabulary)
: this(SparkEnvironment.JvmBridge.CallConstructor(
s_countVectorizerModelClassName, uid, vocabulary))
{
}
-
+
internal CountVectorizerModel(JvmObjectReference jvmObject) : base(jvmObject)
{
}
-
+
///
/// Loads the that was previously saved using Save
///
@@ -52,7 +55,7 @@ namespace Microsoft.Spark.ML.Feature
WrapAsCountVectorizerModel(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_countVectorizerModelClassName, "load", path));
-
+
///
/// Gets the binary toggle to control the output vector values. If True, all nonzero counts
/// (after minTF filter applied) are set to 1. This is useful for discrete probabilistic
@@ -79,7 +82,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// string, the input column
public string GetInputCol() => (string)Reference.Invoke("getInputCol");
-
+
///
/// Sets the column that the should read from.
///
@@ -87,14 +90,14 @@ namespace Microsoft.Spark.ML.Feature
/// with the input column set
public CountVectorizerModel SetInputCol(string value) =>
WrapAsCountVectorizerModel(Reference.Invoke("setInputCol", value));
-
+
///
/// Gets the name of the new column the will create in
/// the DataFrame.
///
/// The name of the output column.
public string GetOutputCol() => (string)Reference.Invoke("getOutputCol");
-
+
///
/// Sets the name of the new column the will create in
/// the DataFrame.
@@ -103,7 +106,7 @@ namespace Microsoft.Spark.ML.Feature
/// New with the output column set
public CountVectorizerModel SetOutputCol(string value) =>
WrapAsCountVectorizerModel(Reference.Invoke("setOutputCol", value));
-
+
///
/// Gets the maximum number of different documents a term could appear in to be included in
/// the vocabulary. A term that appears more than the threshold will be ignored. If this is
@@ -113,7 +116,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The maximum document term frequency of type double.
public double GetMaxDF() => (double)Reference.Invoke("getMaxDF");
-
+
///
/// Gets the minimum number of different documents a term must appear in to be included in
/// the vocabulary. If this is an integer greater than or equal to 1, this specifies the
@@ -152,7 +155,7 @@ namespace Microsoft.Spark.ML.Feature
///
public CountVectorizerModel SetMinTF(double value) =>
WrapAsCountVectorizerModel(Reference.Invoke("setMinTF", value));
-
+
///
/// Gets the max size of the vocabulary. will build a
/// vocabulary that only considers the top vocabSize terms ordered by term frequency across
@@ -160,7 +163,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The max size of the vocabulary of type int.
public int GetVocabSize() => (int)Reference.Invoke("getVocabSize");
-
+
///
/// Check transform validity and derive the output schema from the input schema.
///
@@ -177,21 +180,42 @@ namespace Microsoft.Spark.ML.Feature
/// The of the output schema that would have been derived from the
/// input schema, if Transform had been called.
///
- public StructType TransformSchema(StructType value) =>
+ public override StructType TransformSchema(StructType value) =>
new StructType(
(JvmObjectReference)Reference.Invoke(
- "transformSchema",
+ "transformSchema",
DataType.FromJson(Reference.Jvm, value.Json)));
-
+
///
/// Converts a DataFrame with a text document to a sparse vector of token counts.
///
/// to transform
/// containing the original data and the counts
- public DataFrame Transform(DataFrame document) =>
+ public override DataFrame Transform(DataFrame document) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", document));
-
- private static CountVectorizerModel WrapAsCountVectorizerModel(object obj) =>
+
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader<CountVectorizerModel>"/> instance for this ML instance.</returns>
+ public JavaMLReader<CountVectorizerModel> Read() =>
+ new JavaMLReader<CountVectorizerModel>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static CountVectorizerModel WrapAsCountVectorizerModel(object obj) =>
new CountVectorizerModel((JvmObjectReference)obj);
}
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Estimator.cs b/src/csharp/Microsoft.Spark/ML/Feature/Estimator.cs
new file mode 100644
index 00000000..2208424a
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Estimator.cs
@@ -0,0 +1,45 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.ML.Feature
+{
+ /// <summary>
+ /// A helper interface for JavaEstimator, so that when we have an array of JavaEstimators
+ /// with different type params, we can hold all of them with Estimator<object>.
+ /// </summary>
+ public interface IEstimator<out M>
+ {
+ M Fit(DataFrame dataset);
+ }
+
+ /// <summary>
+ /// Abstract Class for estimators that fit models to data.
+ /// </summary>
+ /// <typeparam name="M">The model type produced by fitting.</typeparam>
+ public abstract class JavaEstimator<M> : JavaPipelineStage, IEstimator<M> where M : JavaModel<M>
+ {
+ internal JavaEstimator(string className) : base(className)
+ {
+ }
+
+ internal JavaEstimator(string className, string uid) : base(className, uid)
+ {
+ }
+
+ internal JavaEstimator(JvmObjectReference jvmObject) : base(jvmObject)
+ {
+ }
+
+ /// <summary>
+ /// Fits a model to the input data.
+ /// </summary>
+ /// <param name="dataset">input dataset.</param>
+ /// <returns>fitted model</returns>
+ public virtual M Fit(DataFrame dataset) =>
+ WrapAsType<M>((JvmObjectReference)Reference.Invoke("fit", dataset));
+ }
+}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Evaluator.cs b/src/csharp/Microsoft.Spark/ML/Feature/Evaluator.cs
new file mode 100644
index 00000000..1d0deef8
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Evaluator.cs
@@ -0,0 +1,45 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.ML.Feature
+{
+ /// <summary>
+ /// Abstract Class for evaluators that compute metrics from predictions.
+ /// </summary>
+ public abstract class JavaEvaluator : Params
+ {
+ internal JavaEvaluator(string className) : base(className)
+ {
+ }
+
+ internal JavaEvaluator(string className, string uid) : base(className, uid)
+ {
+ }
+
+ internal JavaEvaluator(JvmObjectReference jvmObject) : base(jvmObject)
+ {
+ }
+
+ /// <summary>
+ /// Evaluates model output and returns a scalar metric.
+ /// The value of isLargerBetter specifies whether larger values are better.
+ /// </summary>
+ /// <param name="dataset">a dataset that contains labels/observations and predictions.</param>
+ /// <returns>metric</returns>
+ public virtual double Evaluate(DataFrame dataset) =>
+ (double)Reference.Invoke("evaluate", dataset);
+
+ /// <summary>
+ /// Indicates whether the metric returned by evaluate should be maximized
+ /// (true, default) or minimized (false).
+ /// A given evaluator may support multiple metrics which may be maximized or minimized.
+ /// </summary>
+ /// <returns>bool</returns>
+ public bool IsLargerBetter =>
+ (bool)Reference.Invoke("isLargerBetter");
+ }
+}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/FeatureHasher.cs b/src/csharp/Microsoft.Spark/ML/Feature/FeatureHasher.cs
index c79b6411..5c844396 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/FeatureHasher.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/FeatureHasher.cs
@@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
-using System.Linq;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
@@ -11,11 +10,14 @@ using Microsoft.Spark.Sql.Types;
namespace Microsoft.Spark.ML.Feature
{
- public class FeatureHasher: FeatureBase<FeatureHasher>
+ public class FeatureHasher :
+ JavaTransformer,
+ IJavaMLWritable,
+ IJavaMLReadable<FeatureHasher>
{
- private static readonly string s_featureHasherClassName =
+ private static readonly string s_featureHasherClassName =
"org.apache.spark.ml.feature.FeatureHasher";
-
+
///
/// Creates a without any parameters.
///
@@ -35,7 +37,7 @@ namespace Microsoft.Spark.ML.Feature
internal FeatureHasher(JvmObjectReference jvmObject) : base(jvmObject)
{
}
-
+
///
/// Loads the that was previously saved using Save.
///
@@ -49,22 +51,22 @@ namespace Microsoft.Spark.ML.Feature
s_featureHasherClassName,
"load",
path));
-
+
///
/// Gets a list of the columns which have been specified as categorical columns.
///
/// <returns>List of categorical columns, set by SetCategoricalCols</returns>
- public IEnumerable<string> GetCategoricalCols() =>
+ public IEnumerable<string> GetCategoricalCols() =>
(string[])Reference.Invoke("getCategoricalCols");
-
+
///
/// Marks columns as categorical columns.
///
/// List of column names to mark as categorical columns
/// New object
- public FeatureHasher SetCategoricalCols(IEnumerable<string> value) =>
+ public FeatureHasher SetCategoricalCols(IEnumerable<string> value) =>
WrapAsFeatureHasher(Reference.Invoke("setCategoricalCols", value));
-
+
///
/// Gets the columns that the should read from and convert into
/// hashes. This would have been set by SetInputCol.
@@ -78,9 +80,9 @@ namespace Microsoft.Spark.ML.Feature
///
/// The name of the column to as use the source of the hash
/// New object
- public FeatureHasher SetInputCols(IEnumerable<string> value) =>
+ public FeatureHasher SetInputCols(IEnumerable<string> value) =>
WrapAsFeatureHasher(Reference.Invoke("setInputCols", value));
-
+
///
/// Gets the number of features that should be used. Since a simple modulo is used to
/// transform the hash function to a column index, it is advisable to use a power of two
@@ -98,9 +100,9 @@ namespace Microsoft.Spark.ML.Feature
///
/// int value of number of features
/// New object
- public FeatureHasher SetNumFeatures(int value) =>
+ public FeatureHasher SetNumFeatures(int value) =>
WrapAsFeatureHasher(Reference.Invoke("setNumFeatures", value));
-
+
///
/// Gets the name of the column the output data will be written to. This is set by
/// SetInputCol.
@@ -113,18 +115,18 @@ namespace Microsoft.Spark.ML.Feature
///
/// The name of the new column which will contain the hash
/// New object
- public FeatureHasher SetOutputCol(string value) =>
+ public FeatureHasher SetOutputCol(string value) =>
WrapAsFeatureHasher(Reference.Invoke("setOutputCol", value));
-
+
///
/// Transforms the input . It is recommended that you validate that
/// the transform will succeed by calling TransformSchema.
///
/// Input to transform
/// Transformed
- public DataFrame Transform(DataFrame value) =>
+ public override DataFrame Transform(DataFrame value) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", value));
-
+
///
/// Check transform validity and derive the output schema from the input schema.
///
@@ -141,13 +143,34 @@ namespace Microsoft.Spark.ML.Feature
/// The of the output schema that would have been derived from the
/// input schema, if Transform had been called.
///
- public StructType TransformSchema(StructType value) =>
+ public override StructType TransformSchema(StructType value) =>
new StructType(
(JvmObjectReference)Reference.Invoke(
- "transformSchema",
+ "transformSchema",
DataType.FromJson(Reference.Jvm, value.Json)));
- private static FeatureHasher WrapAsFeatureHasher(object obj) =>
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader<FeatureHasher>"/> instance for this ML instance.</returns>
+ public JavaMLReader<FeatureHasher> Read() =>
+ new JavaMLReader<FeatureHasher>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static FeatureHasher WrapAsFeatureHasher(object obj) =>
new FeatureHasher((JvmObjectReference)obj);
}
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/HashingTF.cs b/src/csharp/Microsoft.Spark/ML/Feature/HashingTF.cs
index 418640e7..2ec9b391 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/HashingTF.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/HashingTF.cs
@@ -2,12 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using System.Collections.Generic;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
-using Microsoft.Spark.Sql.Types;
namespace Microsoft.Spark.ML.Feature
{
@@ -19,9 +16,12 @@ namespace Microsoft.Spark.ML.Feature
/// power of two as the numFeatures parameter; otherwise the features will not be mapped evenly
/// to the columns.
///
- public class HashingTF : FeatureBase<HashingTF>
+ public class HashingTF :
+ JavaTransformer,
+ IJavaMLWritable,
+ IJavaMLReadable<HashingTF>
{
- private static readonly string s_hashingTfClassName =
+ private static readonly string s_hashingTfClassName =
"org.apache.spark.ml.feature.HashingTF";
///
@@ -39,7 +39,7 @@ namespace Microsoft.Spark.ML.Feature
public HashingTF(string uid) : base(s_hashingTfClassName, uid)
{
}
-
+
internal HashingTF(JvmObjectReference jvmObject) : base(jvmObject)
{
}
@@ -66,7 +66,7 @@ namespace Microsoft.Spark.ML.Feature
/// models that model binary events rather than integer counts
///
/// binary toggle, default is false
- public HashingTF SetBinary(bool value) =>
+ public HashingTF SetBinary(bool value) =>
WrapAsHashingTF(Reference.Invoke("setBinary", value));
///
@@ -80,7 +80,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The name of the column to as the source
/// New object
- public HashingTF SetInputCol(string value) =>
+ public HashingTF SetInputCol(string value) =>
WrapAsHashingTF(Reference.Invoke("setInputCol", value));
///
@@ -96,7 +96,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The name of the new column
/// New object
- public HashingTF SetOutputCol(string value) =>
+ public HashingTF SetOutputCol(string value) =>
WrapAsHashingTF(Reference.Invoke("setOutputCol", value));
///
@@ -116,7 +116,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// int
/// New object
- public HashingTF SetNumFeatures(int value) =>
+ public HashingTF SetNumFeatures(int value) =>
WrapAsHashingTF(Reference.Invoke("setNumFeatures", value));
///
@@ -125,10 +125,31 @@ namespace Microsoft.Spark.ML.Feature
///
/// The to add the tokens to
/// containing the original data and the tokens
- public DataFrame Transform(DataFrame source) =>
+ public override DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", source));
- private static HashingTF WrapAsHashingTF(object obj) =>
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader<HashingTF>"/> instance for this ML instance.</returns>
+ public JavaMLReader<HashingTF> Read() =>
+ new JavaMLReader<HashingTF>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static HashingTF WrapAsHashingTF(object obj) =>
new HashingTF((JvmObjectReference)obj);
}
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/IDF.cs b/src/csharp/Microsoft.Spark/ML/Feature/IDF.cs
index d4f50168..7c9b5dc1 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/IDF.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/IDF.cs
@@ -17,10 +17,13 @@ namespace Microsoft.Spark.ML.Feature
/// of documents (controlled by the variable minDocFreq). For terms that are not in at least
/// minDocFreq documents, the IDF is found as 0, resulting in TF-IDFs of 0.
///
- public class IDF : FeatureBase<IDF>
+ public class IDF :
+ JavaEstimator<IDFModel>,
+ IJavaMLWritable,
+ IJavaMLReadable<IDF>
{
private static readonly string s_IDFClassName = "org.apache.spark.ml.feature.IDF";
-
+
///
/// Create a without any parameters
///
@@ -36,11 +39,11 @@ namespace Microsoft.Spark.ML.Feature
public IDF(string uid) : base(s_IDFClassName, uid)
{
}
-
+
internal IDF(JvmObjectReference jvmObject) : base(jvmObject)
{
}
-
+
///
/// Gets the column that the should read from
///
@@ -67,7 +70,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The name of the new column
/// New object
- public IDF SetOutputCol(string value) =>
+ public IDF SetOutputCol(string value) =>
WrapAsIDF(Reference.Invoke("setOutputCol", value));
///
@@ -81,7 +84,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// int, the minimum of documents a term should appear in
/// New object
- public IDF SetMinDocFreq(int value) =>
+ public IDF SetMinDocFreq(int value) =>
WrapAsIDF(Reference.Invoke("setMinDocFreq", value));
///
@@ -89,7 +92,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The to fit the model to
/// New object
- public IDFModel Fit(DataFrame source) =>
+ public override IDFModel Fit(DataFrame source) =>
new IDFModel((JvmObjectReference)Reference.Invoke("fit", source));
///
@@ -102,7 +105,28 @@ namespace Microsoft.Spark.ML.Feature
return WrapAsIDF(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_IDFClassName, "load", path));
}
-
+
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader<IDF>"/> instance for this ML instance.</returns>
+ public JavaMLReader<IDF> Read() =>
+ new JavaMLReader<IDF>((JvmObjectReference)Reference.Invoke("read"));
+
private static IDF WrapAsIDF(object obj) => new IDF((JvmObjectReference)obj);
}
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/IDFModel.cs b/src/csharp/Microsoft.Spark/ML/Feature/IDFModel.cs
index c40b35e9..91a577ec 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/IDFModel.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/IDFModel.cs
@@ -12,11 +12,14 @@ namespace Microsoft.Spark.ML.Feature
/// A that converts the input string to lowercase and then splits it by
/// white spaces.
///
- public class IDFModel : FeatureBase<IDFModel>
+ public class IDFModel :
+ JavaModel<IDFModel>,
+ IJavaMLWritable,
+ IJavaMLReadable<IDFModel>
{
- private static readonly string s_IDFModelClassName =
+ private static readonly string s_IDFModelClassName =
"org.apache.spark.ml.feature.IDFModel";
-
+
///
/// Create a without any parameters
///
@@ -32,11 +35,11 @@ namespace Microsoft.Spark.ML.Feature
public IDFModel(string uid) : base(s_IDFModelClassName, uid)
{
}
-
+
internal IDFModel(JvmObjectReference jvmObject) : base(jvmObject)
{
}
-
+
///
/// Gets the column that the should read from
///
@@ -49,7 +52,7 @@ namespace Microsoft.Spark.ML.Feature
///
/// The name of the column to as the source
/// New object
- public IDFModel SetInputCol(string value) =>
+ public IDFModel SetInputCol(string value) =>
WrapAsIDFModel(Reference.Invoke("setInputCol", value));
///
@@ -66,7 +69,7 @@ namespace Microsoft.Spark.ML.Feature
/// The name of the new column which contains the tokens
///
/// New object
- public IDFModel SetOutputCol(string value) =>
+ public IDFModel SetOutputCol(string value) =>
WrapAsIDFModel(Reference.Invoke("setOutputCol", value));
///
@@ -81,7 +84,7 @@ namespace Microsoft.Spark.ML.Feature
/// </summary>
/// <param name="source">The <see cref="DataFrame"/> to add the tokens to</param>
/// <returns><see cref="DataFrame"/> containing the original data and the tokens</returns>
- public DataFrame Transform(DataFrame source) =>
+ public override DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", source));
///
@@ -96,7 +99,28 @@ namespace Microsoft.Spark.ML.Feature
s_IDFModelClassName, "load", path));
}
- private static IDFModel WrapAsIDFModel(object obj) =>
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{IDFModel}"/> instance for this ML instance.</returns>
+ public JavaMLReader<IDFModel> Read() =>
+ new JavaMLReader<IDFModel>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static IDFModel WrapAsIDFModel(object obj) =>
new IDFModel((JvmObjectReference)obj);
}
}
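A companion sketch for the fitted-model side (the path and column names are illustrative):

    using Microsoft.Spark.ML.Feature;
    using Microsoft.Spark.Sql;

    static DataFrame WeightByIdf(DataFrame tf)
    {
        IDFModel model = IDFModel.Load("/tmp/idf-model");

        // Transform now overrides JavaTransformer.Transform, so an IDFModel
        // can also be handled through the generic JavaTransformer base.
        return model
            .SetInputCol("rawFeatures")
            .SetOutputCol("features")
            .Transform(tf);
    }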
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Model.cs b/src/csharp/Microsoft.Spark/ML/Feature/Model.cs
new file mode 100644
index 00000000..80960717
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Model.cs
@@ -0,0 +1,53 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.ML.Feature
+{
+ /// <summary>
+ /// A helper interface for JavaModel, so that when we have an array of JavaModels
+ /// with different type params, we can hold all of them with Model&lt;object&gt;.
+ /// </summary>
+ public interface IModel
+ {
+ bool HasParent();
+ }
+
+ /// <summary>
+ /// A fitted model, i.e., a Transformer produced by an Estimator.
+ /// </summary>
+ /// <typeparam name="M">
+ /// Model Type.
+ /// </typeparam>
+ public abstract class JavaModel<M> : JavaTransformer, IModel where M : JavaModel<M>
+ {
+ internal JavaModel(string className) : base(className)
+ {
+ }
+
+ internal JavaModel(string className, string uid) : base(className, uid)
+ {
+ }
+
+ internal JavaModel(JvmObjectReference jvmObject) : base(jvmObject)
+ {
+ }
+
+ /// <summary>
+ /// Sets the parent of this model.
+ /// </summary>
+ /// <param name="parent">The parent of the JavaModel to be set</param>
+ /// <returns>type parameter M</returns>
+ public M SetParent(JavaEstimator<M> parent) =>
+ WrapAsType<M>((JvmObjectReference)Reference.Invoke("setParent", parent));
+
+ /// <summary>
+ /// Indicates whether this Model has a corresponding parent.
+ /// </summary>
+ /// <returns>bool</returns>
+ public bool HasParent() =>
+ (bool)Reference.Invoke("hasParent");
+ }
+}
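The IModel indirection can be exercised roughly like this (a sketch; the concrete model types passed in are only examples):

    using System;
    using Microsoft.Spark.ML.Feature;

    static void InspectParents(IModel[] models)
    {
        // IDFModel and Word2VecModel close JavaModel<M> over different M,
        // but both can be held and queried through the non-generic IModel.
        foreach (IModel model in models)
        {
            Console.WriteLine(model.HasParent());
        }
    }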
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/NGram.cs b/src/csharp/Microsoft.Spark/ML/Feature/NGram.cs
index b5e9be1a..fd0995d6 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/NGram.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/NGram.cs
@@ -14,7 +14,10 @@ namespace Microsoft.Spark.ML.Feature
/// an array of n-grams. Null values in the input array are ignored. It returns an array
/// of n-grams where each n-gram is represented by a space-separated string of words.
///
- public class NGram : FeatureBase<NGram>
+ public class NGram :
+ JavaTransformer,
+ IJavaMLWritable,
+ IJavaMLReadable<NGram>
{
private static readonly string s_nGramClassName =
"org.apache.spark.ml.feature.NGram";
@@ -87,7 +90,7 @@ namespace Microsoft.Spark.ML.Feature
/// <returns>
/// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed.
/// </returns>
- public DataFrame Transform(DataFrame source) =>
+ public override DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", source));
///
@@ -106,7 +109,7 @@ namespace Microsoft.Spark.ML.Feature
/// The of the output schema that would have been derived from the
/// input schema, if Transform had been called.
///
- public StructType TransformSchema(StructType value) =>
+ public override StructType TransformSchema(StructType value) =>
new StructType(
(JvmObjectReference)Reference.Invoke(
"transformSchema",
@@ -124,6 +127,27 @@ namespace Microsoft.Spark.ML.Feature
"load",
path));
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{NGram}"/> instance for this ML instance.</returns>
+ public JavaMLReader<NGram> Read() =>
+ new JavaMLReader<NGram>((JvmObjectReference)Reference.Invoke("read"));
+
private static NGram WrapAsNGram(object obj) => new NGram((JvmObjectReference)obj);
}
}
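A minimal transformer sketch, assuming a DataFrame `tokens` with an array-of-strings column named "words" (the column names are illustrative):

    using Microsoft.Spark.ML.Feature;
    using Microsoft.Spark.Sql;

    static DataFrame ToNGrams(DataFrame tokens)
    {
        NGram ngram = new NGram()
            .SetInputCol("words")
            .SetOutputCol("ngrams");

        // Transform is now an override of JavaTransformer.Transform.
        return ngram.Transform(tokens);
    }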
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Pipeline.cs b/src/csharp/Microsoft.Spark/ML/Feature/Pipeline.cs
new file mode 100644
index 00000000..1e83f4ed
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Pipeline.cs
@@ -0,0 +1,120 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Reflection;
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+
+namespace Microsoft.Spark.ML.Feature
+{
+ /// <summary>
+ /// A simple pipeline, which acts as an estimator.
+ /// A Pipeline consists of a sequence of stages, each of which is either an Estimator or a Transformer.
+ /// When Pipeline.fit is called, the stages are executed in order. If a stage is an Estimator, its
+ /// Estimator.fit method will be called on the input dataset to fit a model. Then the model, which is a
+ /// transformer, will be used to transform the dataset as the input to the next stage.
+ /// If a stage is a Transformer, its Transformer.transform method will be called to produce the
+ /// dataset for the next stage. The fitted model from a Pipeline is a PipelineModel, which consists of
+ /// fitted models and transformers, corresponding to the pipeline stages.
+ /// If there are no stages, the pipeline acts as an identity transformer.
+ /// </summary>
+ public class Pipeline :
+ JavaEstimator<PipelineModel>,
+ IJavaMLWritable,
+ IJavaMLReadable<Pipeline>
+ {
+ private static readonly string s_pipelineClassName = "org.apache.spark.ml.Pipeline";
+
+ /// <summary>
+ /// Creates a <see cref="Pipeline"/> without any parameters.
+ /// </summary>
+ public Pipeline() : base(s_pipelineClassName)
+ {
+ }
+
+ /// <summary>
+ /// Creates a <see cref="Pipeline"/> with a UID that is used to give the
+ /// <see cref="Pipeline"/> a unique ID.
+ /// </summary>
+ /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
+ public Pipeline(string uid) : base(s_pipelineClassName, uid)
+ {
+ }
+
+ internal Pipeline(JvmObjectReference jvmObject) : base(jvmObject)
+ {
+ }
+
+ /// <summary>
+ /// Set the stages of the pipeline instance.
+ /// </summary>
+ /// <param name="value">
+ /// A sequence of stages, each of which is either an Estimator or a Transformer.
+ /// </param>
+ /// <returns>New <see cref="Pipeline"/> object</returns>
+ public Pipeline SetStages(JavaPipelineStage[] value) =>
+ WrapAsPipeline((JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+ "org.apache.spark.mllib.api.dotnet.MLUtils", "setPipelineStages",
+ Reference, value.ToJavaArrayList()));
+
+ /// <summary>
+ /// Get the stages of the pipeline instance.
+ /// </summary>
+ /// <returns>A sequence of stages</returns>
+ public JavaPipelineStage[] GetStages()
+ {
+ JvmObjectReference[] jvmObjects = (JvmObjectReference[])Reference.Invoke("getStages");
+ JavaPipelineStage[] result = new JavaPipelineStage[jvmObjects.Length];
+ for (int i = 0; i < jvmObjects.Length; i++)
+ {
+ (string constructorClass, string methodName) = DotnetUtils.GetUnderlyingType(jvmObjects[i]);
+ Type type = Type.GetType(constructorClass);
+ MethodInfo method = type.GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Static);
+ result[i] = (JavaPipelineStage)method.Invoke(null, new object[] { jvmObjects[i] });
+ }
+ return result;
+ }
+
+ /// <summary>Fits a model to the input data.</summary>
+ /// <param name="dataset">The <see cref="DataFrame"/> to fit the model to.</param>
+ /// <returns><see cref="PipelineModel"/></returns>
+ override public PipelineModel Fit(DataFrame dataset) =>
+ new PipelineModel(
+ (JvmObjectReference)Reference.Invoke("fit", dataset));
+
+ /// <summary>
+ /// Loads the <see cref="Pipeline"/> that was previously saved using Save(string).
+ /// </summary>
+ /// <param name="path">The path the previous <see cref="Pipeline"/> was saved to</param>
+ /// <returns>New <see cref="Pipeline"/> object, loaded from path.</returns>
+ public static Pipeline Load(string path) => WrapAsPipeline(
+ SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_pipelineClassName, "load", path));
+
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{Pipeline}"/> instance for this ML instance.</returns>
+ public JavaMLReader<Pipeline> Read() =>
+ new JavaMLReader<Pipeline>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static Pipeline WrapAsPipeline(object obj) =>
+ new Pipeline((JvmObjectReference)obj);
+ }
+}
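For context, a sketch of how the stages plumbing is meant to be used (column names are illustrative; `text` is assumed to have a string column named "text"):

    using Microsoft.Spark.ML.Feature;
    using Microsoft.Spark.Sql;

    static PipelineModel FitTextPipeline(DataFrame text)
    {
        var tokenizer = new Tokenizer().SetInputCol("text").SetOutputCol("words");
        var remover = new StopWordsRemover().SetInputCol("words").SetOutputCol("filtered");

        // Transformers and estimators both derive from JavaPipelineStage,
        // so they can be mixed in a single stages array.
        Pipeline pipeline = new Pipeline()
            .SetStages(new JavaPipelineStage[] { tokenizer, remover });

        return pipeline.Fit(text);
    }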
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/PipelineModel.cs b/src/csharp/Microsoft.Spark/ML/Feature/PipelineModel.cs
new file mode 100644
index 00000000..e9848b75
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/PipelineModel.cs
@@ -0,0 +1,70 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.ML.Feature
+{
+ /// <summary>
+ /// Represents a fitted pipeline.
+ /// </summary>
+ public class PipelineModel :
+ JavaModel<PipelineModel>,
+ IJavaMLWritable,
+ IJavaMLReadable<PipelineModel>
+ {
+ private static readonly string s_pipelineModelClassName = "org.apache.spark.ml.PipelineModel";
+
+ /// <summary>
+ /// Creates a <see cref="PipelineModel"/> with a UID that is used to give the
+ /// <see cref="PipelineModel"/> a unique ID, and an array of transformers as stages.
+ /// </summary>
+ /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
+ /// <param name="stages">Stages for the PipelineModel.</param>
+ public PipelineModel(string uid, JavaTransformer[] stages)
+ : this(SparkEnvironment.JvmBridge.CallConstructor(
+ s_pipelineModelClassName, uid, stages.ToJavaArrayList()))
+ {
+ }
+
+ internal PipelineModel(JvmObjectReference jvmObject) : base(jvmObject)
+ {
+ }
+
+ /// <summary>
+ /// Loads the <see cref="PipelineModel"/> that was previously saved using Save(string).
+ /// </summary>
+ /// <param name="path">The path the previous <see cref="PipelineModel"/> was saved to</param>
+ /// <returns>New <see cref="PipelineModel"/> object, loaded from path.</returns>
+ public static PipelineModel Load(string path) => WrapAsPipelineModel(
+ SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_pipelineModelClassName, "load", path));
+
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{PipelineModel}"/> instance for this ML instance.</returns>
+ public JavaMLReader<PipelineModel> Read() =>
+ new JavaMLReader<PipelineModel>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static PipelineModel WrapAsPipelineModel(object obj) =>
+ new PipelineModel((JvmObjectReference)obj);
+
+ }
+}
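Note the public constructor: a PipelineModel can also be assembled directly from transformers, without fitting a Pipeline first. A sketch (the uid and column names are illustrative):

    using Microsoft.Spark.ML.Feature;
    using Microsoft.Spark.Sql;

    static DataFrame TokenizeViaModel(DataFrame input)
    {
        var tokenizer = new Tokenizer().SetInputCol("text").SetOutputCol("words");

        // "sample-uid" is an arbitrary identifier for the assembled model.
        var model = new PipelineModel("sample-uid", new JavaTransformer[] { tokenizer });
        return model.Transform(input);
    }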
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/PipelineStage.cs b/src/csharp/Microsoft.Spark/ML/Feature/PipelineStage.cs
new file mode 100644
index 00000000..3420b378
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/PipelineStage.cs
@@ -0,0 +1,50 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.ML.Feature
+{
+ /// <summary>
+ /// A stage in a pipeline, either an Estimator or a Transformer.
+ /// </summary>
+ public abstract class JavaPipelineStage : Params
+ {
+ internal JavaPipelineStage(string className) : base(className)
+ {
+ }
+
+ internal JavaPipelineStage(string className, string uid) : base(className, uid)
+ {
+ }
+
+ internal JavaPipelineStage(JvmObjectReference jvmObject) : base(jvmObject)
+ {
+ }
+
+ /// <summary>
+ /// Check transform validity and derive the output schema from the input schema.
+ ///
+ /// We check validity for interactions between parameters during transformSchema
+ /// and raise an exception if any parameter value is invalid.
+ ///
+ /// A typical implementation should first conduct verification on schema change and
+ /// parameter validity, including complex parameter interaction checks.
+ /// </summary>
+ /// <param name="schema">
+ /// The <see cref="StructType"/> of the <see cref="DataFrame"/> which will be transformed.
+ /// </param>
+ /// <returns>
+ /// The <see cref="StructType"/> of the output schema that would have been derived from the
+ /// input schema, if Transform had been called.
+ /// </returns>
+ public virtual StructType TransformSchema(StructType schema) =>
+ new StructType(
+ (JvmObjectReference)Reference.Invoke(
+ "transformSchema",
+ DataType.FromJson(Reference.Jvm, schema.Json)));
+ }
+}
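Because every stage exposes TransformSchema, a pipeline's output schema can be derived without touching any data. A small sketch under that assumption:

    using Microsoft.Spark.ML.Feature;
    using Microsoft.Spark.Sql.Types;

    static StructType DeriveOutputSchema(JavaPipelineStage[] stages, StructType schema)
    {
        // Fold the input schema through each stage in order; no job is run.
        foreach (JavaPipelineStage stage in stages)
        {
            schema = stage.TransformSchema(schema);
        }
        return schema;
    }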
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/SQLTransformer.cs b/src/csharp/Microsoft.Spark/ML/Feature/SQLTransformer.cs
index c6f8f299..530bbc27 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/SQLTransformer.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/SQLTransformer.cs
@@ -12,9 +12,12 @@ namespace Microsoft.Spark.ML.Feature
/// <summary>
/// <see cref="SQLTransformer"/> implements the transformations which are defined by a SQL statement.
/// </summary>
- public class SQLTransformer : FeatureBase<SQLTransformer>
+ public class SQLTransformer :
+ JavaTransformer,
+ IJavaMLWritable,
+ IJavaMLReadable<SQLTransformer>
{
- private static readonly string s_sqlTransformerClassName =
+ private static readonly string s_sqlTransformerClassName =
"org.apache.spark.ml.feature.SQLTransformer";
///
@@ -45,7 +48,7 @@ namespace Microsoft.Spark.ML.Feature
/// <returns>
/// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed.
/// </returns>
- public DataFrame Transform(DataFrame source) =>
+ public override DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", source));
///
@@ -55,7 +58,7 @@ namespace Microsoft.Spark.ML.Feature
/// <returns>
/// New <see cref="StructType"/> object with the schema transformed.
/// </returns>
- public StructType TransformSchema(StructType value) =>
+ public override StructType TransformSchema(StructType value) =>
new StructType(
(JvmObjectReference)Reference.Invoke(
"transformSchema",
@@ -82,13 +85,34 @@ namespace Microsoft.Spark.ML.Feature
/// </summary>
/// <param name="path">The path the previous <see cref="SQLTransformer"/> was saved to</param>
/// <returns>New <see cref="SQLTransformer"/> object, loaded from path</returns>
- public static SQLTransformer Load(string path) =>
+ public static SQLTransformer Load(string path) =>
WrapAsSQLTransformer(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
- s_sqlTransformerClassName,
- "load",
+ s_sqlTransformerClassName,
+ "load",
path));
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{SQLTransformer}"/> instance for this ML instance.</returns>
+ public JavaMLReader<SQLTransformer> Read() =>
+ new JavaMLReader<SQLTransformer>((JvmObjectReference)Reference.Invoke("read"));
+
private static SQLTransformer WrapAsSQLTransformer(object obj) =>
new SQLTransformer((JvmObjectReference)obj);
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/StopWordsRemover.cs b/src/csharp/Microsoft.Spark/ML/Feature/StopWordsRemover.cs
index a1484155..64ec7585 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/StopWordsRemover.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/StopWordsRemover.cs
@@ -13,7 +13,10 @@ namespace Microsoft.Spark.ML.Feature
/// <summary>
/// A feature transformer that filters out stop words from input.
/// </summary>
- public class StopWordsRemover : FeatureBase<StopWordsRemover>
+ public class StopWordsRemover :
+ JavaTransformer,
+ IJavaMLWritable,
+ IJavaMLReadable<StopWordsRemover>
{
private static readonly string s_stopWordsRemoverClassName =
"org.apache.spark.ml.feature.StopWordsRemover";
@@ -63,7 +66,7 @@ namespace Microsoft.Spark.ML.Feature
/// <returns>
/// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed
/// </returns>
- public DataFrame Transform(DataFrame source) =>
+ public override DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", source));
///
@@ -141,7 +144,7 @@ namespace Microsoft.Spark.ML.Feature
/// The of the output schema that would have been derived from the
/// input schema, if Transform had been called.
///
- public StructType TransformSchema(StructType value) =>
+ public override StructType TransformSchema(StructType value) =>
new StructType(
(JvmObjectReference)Reference.Invoke(
"transformSchema",
@@ -168,6 +171,27 @@ namespace Microsoft.Spark.ML.Feature
WrapAsStopWordsRemover(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_stopWordsRemoverClassName, "load", path));
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{StopWordsRemover}"/> instance for this ML instance.</returns>
+ public JavaMLReader<StopWordsRemover> Read() =>
+ new JavaMLReader<StopWordsRemover>((JvmObjectReference)Reference.Invoke("read"));
+
private static StopWordsRemover WrapAsStopWordsRemover(object obj) =>
new StopWordsRemover((JvmObjectReference)obj);
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Tokenizer.cs b/src/csharp/Microsoft.Spark/ML/Feature/Tokenizer.cs
index 3cf81e23..246f01d7 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/Tokenizer.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Tokenizer.cs
@@ -12,11 +12,14 @@ namespace Microsoft.Spark.ML.Feature
/// A <see cref="Tokenizer"/> that converts the input string to lowercase and then splits it by
/// white spaces.
/// </summary>
- public class Tokenizer : FeatureBase<Tokenizer>
+ public class Tokenizer :
+ JavaTransformer,
+ IJavaMLWritable,
+ IJavaMLReadable<Tokenizer>
{
- private static readonly string s_tokenizerClassName =
+ private static readonly string s_tokenizerClassName =
"org.apache.spark.ml.feature.Tokenizer";
-
+
/// <summary>
/// Create a <see cref="Tokenizer"/> without any parameters
/// </summary>
@@ -32,11 +35,11 @@ namespace Microsoft.Spark.ML.Feature
public Tokenizer(string uid) : base(s_tokenizerClassName, uid)
{
}
-
+
internal Tokenizer(JvmObjectReference jvmObject) : base(jvmObject)
{
}
-
+
/// <summary>
/// Gets the column that the <see cref="Tokenizer"/> should read from
/// </summary>
@@ -48,7 +51,7 @@ namespace Microsoft.Spark.ML.Feature
/// </summary>
/// <param name="value">The name of the column to use as the source</param>
/// <returns>New <see cref="Tokenizer"/> object</returns>
- public Tokenizer SetInputCol(string value) =>
+ public Tokenizer SetInputCol(string value) =>
WrapAsTokenizer(Reference.Invoke("setInputCol", value));
///
@@ -64,7 +67,7 @@ namespace Microsoft.Spark.ML.Feature
/// </summary>
/// <param name="value">The name of the new column</param>
/// <returns>New <see cref="Tokenizer"/> object</returns>
- public Tokenizer SetOutputCol(string value) =>
+ public Tokenizer SetOutputCol(string value) =>
WrapAsTokenizer(Reference.Invoke("setOutputCol", value));
///
@@ -75,7 +78,7 @@ namespace Microsoft.Spark.ML.Feature
/// <returns>
/// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed
/// </returns>
- public DataFrame Transform(DataFrame source) =>
+ public override DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", source));
///
@@ -89,8 +92,29 @@ namespace Microsoft.Spark.ML.Feature
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_tokenizerClassName, "load", path));
}
-
- private static Tokenizer WrapAsTokenizer(object obj) =>
+
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{Tokenizer}"/> instance for this ML instance.</returns>
+ public JavaMLReader<Tokenizer> Read() =>
+ new JavaMLReader<Tokenizer>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static Tokenizer WrapAsTokenizer(object obj) =>
new Tokenizer((JvmObjectReference)obj);
}
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Transformer.cs b/src/csharp/Microsoft.Spark/ML/Feature/Transformer.cs
new file mode 100644
index 00000000..cea34c4c
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Transformer.cs
@@ -0,0 +1,37 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.ML.Feature
+{
+ ///
+ /// Abstract class for transformers that transform one dataset into another.
+ ///
+ public abstract class JavaTransformer : JavaPipelineStage
+ {
+ internal JavaTransformer(string className) : base(className)
+ {
+ }
+
+ internal JavaTransformer(string className, string uid) : base(className, uid)
+ {
+ }
+
+ internal JavaTransformer(JvmObjectReference jvmObject) : base(jvmObject)
+ {
+ }
+
+ ///
+ /// Executes the transformer and transforms the DataFrame to include new columns.
+ ///
+ /// The Dataframe to be transformed.
+ ///
+ /// containing the original data and new columns.
+ ///
+ public virtual DataFrame Transform(DataFrame dataset) =>
+ new DataFrame((JvmObjectReference)Reference.Invoke("transform", dataset));
+ }
+}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Word2Vec.cs b/src/csharp/Microsoft.Spark/ML/Feature/Word2Vec.cs
index 42dac10f..5e09218f 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/Word2Vec.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Word2Vec.cs
@@ -8,9 +8,12 @@ using Microsoft.Spark.Sql;
namespace Microsoft.Spark.ML.Feature
{
- public class Word2Vec : FeatureBase<Word2Vec>
+ public class Word2Vec :
+ JavaEstimator<Word2VecModel>,
+ IJavaMLWritable,
+ IJavaMLReadable<Word2Vec>
{
- private static readonly string s_word2VecClassName =
+ private static readonly string s_word2VecClassName =
"org.apache.spark.ml.feature.Word2Vec";
///
@@ -28,11 +31,11 @@ namespace Microsoft.Spark.ML.Feature
public Word2Vec(string uid) : base(s_word2VecClassName, uid)
{
}
-
+
internal Word2Vec(JvmObjectReference jvmObject) : base(jvmObject)
{
}
-
+
/// <summary>
/// Gets the column that the <see cref="Word2Vec"/> should read from.
/// </summary>
@@ -44,7 +47,7 @@ namespace Microsoft.Spark.ML.Feature
/// </summary>
/// <param name="value">The name of the column to use as the source.</param>
/// <returns>New <see cref="Word2Vec"/> object</returns>
- public Word2Vec SetInputCol(string value) =>
+ public Word2Vec SetInputCol(string value) =>
WrapAsWord2Vec(Reference.Invoke("setInputCol", value));
///
@@ -60,9 +63,9 @@ namespace Microsoft.Spark.ML.Feature
/// </summary>
/// <param name="value">The name of the output column which will be created.</param>
/// <returns>New <see cref="Word2Vec"/> object</returns>
- public Word2Vec SetOutputCol(string value) =>
+ public Word2Vec SetOutputCol(string value) =>
WrapAsWord2Vec(Reference.Invoke("setOutputCol", value));
-
+
/// <summary>
/// Gets the vector size, the dimension of the code that you want to transform from words.
/// </summary>
/// <returns>
/// The vector size, the dimension of the code that you want to transform from words.
/// </returns>
public int GetVectorSize() => (int)(Reference.Invoke("getVectorSize"));
-
+
/// <summary>
/// Sets the vector size, the dimension of the code that you want to transform from words.
/// </summary>
@@ -78,7 +81,7 @@ namespace Microsoft.Spark.ML.Feature
/// The dimension of the code that you want to transform from words.
/// </param>
/// <returns><see cref="Word2Vec"/> object</returns>
- public Word2Vec SetVectorSize(int value) =>
+ public Word2Vec SetVectorSize(int value) =>
WrapAsWord2Vec(Reference.Invoke("setVectorSize", value));
///
@@ -100,9 +103,9 @@ namespace Microsoft.Spark.ML.Feature
/// vocabulary, the default is 5.
/// </param>
/// <returns><see cref="Word2Vec"/> object</returns>
- public virtual Word2Vec SetMinCount(int value) =>
+ public virtual Word2Vec SetMinCount(int value) =>
WrapAsWord2Vec(Reference.Invoke("setMinCount", value));
-
+
/// <summary>Gets the maximum number of iterations.</summary>
/// <returns>The maximum number of iterations.</returns>
public int GetMaxIter() => (int)Reference.Invoke("getMaxIter");
@@ -110,14 +113,14 @@ namespace Microsoft.Spark.ML.Feature
/// <summary>Maximum number of iterations (&gt;= 0).</summary>
/// <param name="value">The number of iterations.</param>
/// <returns><see cref="Word2Vec"/> object</returns>
- public Word2Vec SetMaxIter(int value) =>
+ public Word2Vec SetMaxIter(int value) =>
WrapAsWord2Vec(Reference.Invoke("setMaxIter", value));
/// <summary>
/// Gets the maximum length (in words) of each sentence in the input data.
/// </summary>
/// <returns>The maximum length (in words) of each sentence in the input data.</returns>
- public virtual int GetMaxSentenceLength() =>
+ public virtual int GetMaxSentenceLength() =>
(int)Reference.Invoke("getMaxSentenceLength");
///
@@ -127,13 +130,13 @@ namespace Microsoft.Spark.ML.Feature
/// The maximum length (in words) of each sentence in the input data.
/// </param>
/// <returns><see cref="Word2Vec"/> object</returns>
- public Word2Vec SetMaxSentenceLength(int value) =>
+ public Word2Vec SetMaxSentenceLength(int value) =>
WrapAsWord2Vec(Reference.Invoke("setMaxSentenceLength", value));
/// <summary>Gets the number of partitions for sentences of words.</summary>
/// <returns>The number of partitions for sentences of words.</returns>
public int GetNumPartitions() => (int)Reference.Invoke("getNumPartitions");
-
+
/// <summary>Sets the number of partitions for sentences of words.</summary>
/// <param name="value">
/// The number of partitions for sentences of words, default is 1.
@@ -145,7 +148,7 @@ namespace Microsoft.Spark.ML.Feature
/// <summary>Gets the value that is used for the random seed.</summary>
/// <returns>The value that is used for the random seed.</returns>
public long GetSeed() => (long)Reference.Invoke("getSeed");
-
+
/// <summary>Random seed.</summary>
/// <param name="value">The value to use for the random seed.</param>
/// <returns><see cref="Word2Vec"/> object</returns>
@@ -155,7 +158,7 @@ namespace Microsoft.Spark.ML.Feature
/// <summary>Gets the step size to be used for each iteration of optimization.</summary>
/// <returns>The step size to be used for each iteration of optimization.</returns>
public double GetStepSize() => (double)Reference.Invoke("getStepSize");
-
+
/// <summary>Step size to be used for each iteration of optimization (&gt; 0).</summary>
/// <param name="value">Value to use for the step size.</param>
/// <returns><see cref="Word2Vec"/> object</returns>
@@ -165,7 +168,7 @@ namespace Microsoft.Spark.ML.Feature
/// <summary>Gets the window size (context words from [-window, window]).</summary>
/// <returns>The window size.</returns>
public int GetWindowSize() => (int)Reference.Invoke("getWindowSize");
-
+
/// <summary>The window size (context words from [-window, window]).</summary>
/// <param name="value">
/// The window size (context words from [-window, window]), default is 5.
@@ -173,11 +176,11 @@ namespace Microsoft.Spark.ML.Feature
/// </param>
/// <returns><see cref="Word2Vec"/> object</returns>
public Word2Vec SetWindowSize(int value) =>
WrapAsWord2Vec(Reference.Invoke("setWindowSize", value));
-
+
/// <summary>Fits a model to the input data.</summary>
/// <param name="dataFrame">The <see cref="DataFrame"/> to fit the model to.</param>
/// <returns><see cref="Word2VecModel"/></returns>
- public Word2VecModel Fit(DataFrame dataFrame) =>
+ public override Word2VecModel Fit(DataFrame dataFrame) =>
new Word2VecModel((JvmObjectReference)Reference.Invoke("fit", dataFrame));
///
@@ -187,8 +190,29 @@ namespace Microsoft.Spark.ML.Feature
/// <returns>New <see cref="Word2Vec"/> object, loaded from path.</returns>
public static Word2Vec Load(string path) => WrapAsWord2Vec(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_word2VecClassName, "load", path));
-
- private static Word2Vec WrapAsWord2Vec(object obj) =>
+
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{Word2Vec}"/> instance for this ML instance.</returns>
+ public JavaMLReader<Word2Vec> Read() =>
+ new JavaMLReader<Word2Vec>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static Word2Vec WrapAsWord2Vec(object obj) =>
new Word2Vec((JvmObjectReference)obj);
}
}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/Word2VecModel.cs b/src/csharp/Microsoft.Spark/ML/Feature/Word2VecModel.cs
index ffdb7c7c..319223f1 100644
--- a/src/csharp/Microsoft.Spark/ML/Feature/Word2VecModel.cs
+++ b/src/csharp/Microsoft.Spark/ML/Feature/Word2VecModel.cs
@@ -8,9 +8,12 @@ using Microsoft.Spark.Sql;
namespace Microsoft.Spark.ML.Feature
{
- public class Word2VecModel : FeatureBase<Word2VecModel>
+ public class Word2VecModel :
+ JavaModel<Word2VecModel>,
+ IJavaMLWritable,
+ IJavaMLReadable<Word2VecModel>
{
- private static readonly string s_word2VecModelClassName =
+ private static readonly string s_word2VecModelClassName =
"org.apache.spark.ml.feature.Word2VecModel";
///
@@ -28,18 +31,18 @@ namespace Microsoft.Spark.ML.Feature
public Word2VecModel(string uid) : base(s_word2VecModelClassName, uid)
{
}
-
+
internal Word2VecModel(JvmObjectReference jvmObject) : base(jvmObject)
{
}
-
+
/// <summary>
/// Transform a sentence column to a vector column to represent the whole sentence.
/// </summary>
/// <param name="documentDF"><see cref="DataFrame"/> to transform</param>
- public DataFrame Transform(DataFrame documentDF) =>
+ public override DataFrame Transform(DataFrame documentDF) =>
new DataFrame((JvmObjectReference)Reference.Invoke("transform", documentDF));
-
+
/// <summary>
/// Find a number of words whose vector representations are most similar to
/// the supplied vector. If the supplied vector is the vector representation of a word in
@@ -51,7 +54,7 @@ namespace Microsoft.Spark.ML.Feature
/// <param name="num">The number of words to find that are similar to "word"</param>
public DataFrame FindSynonyms(string word, int num) =>
new DataFrame((JvmObjectReference)Reference.Invoke("findSynonyms", word, num));
-
+
/// <summary>
/// Loads the <see cref="Word2VecModel"/> that was previously saved using Save(string).
/// </summary>
@@ -62,8 +65,29 @@ namespace Microsoft.Spark.ML.Feature
public static Word2VecModel Load(string path) => WrapAsWord2VecModel(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_word2VecModelClassName, "load", path));
-
- private static Word2VecModel WrapAsWord2VecModel(object obj) =>
+
+ /// <summary>
+ /// Saves the object so that it can be loaded later using Load. Note that these objects
+ /// can be shared with Scala by Loading or Saving in Scala.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ public JavaMLWriter Write() =>
+ new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));
+
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{Word2VecModel}"/> instance for this ML instance.</returns>
+ public JavaMLReader<Word2VecModel> Read() =>
+ new JavaMLReader<Word2VecModel>((JvmObjectReference)Reference.Invoke("read"));
+
+ private static Word2VecModel WrapAsWord2VecModel(object obj) =>
new Word2VecModel((JvmObjectReference)obj);
}
}
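An end-to-end sketch of the estimator/model pair (column names and parameter values are illustrative; `tokenized` is assumed to hold an array-of-strings column named "words"):

    using Microsoft.Spark.ML.Feature;
    using Microsoft.Spark.Sql;

    static DataFrame SynonymsOf(DataFrame tokenized, string word)
    {
        Word2Vec word2Vec = new Word2Vec()
            .SetInputCol("words")
            .SetOutputCol("vector")
            .SetVectorSize(100)
            .SetMinCount(1);

        // Fit now returns the strongly typed Word2VecModel.
        Word2VecModel model = word2Vec.Fit(tokenized);

        // FindSynonyms returns a DataFrame of (word, similarity) rows.
        return model.FindSynonyms(word, 5);
    }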
diff --git a/src/csharp/Microsoft.Spark/ML/Param/Param.cs b/src/csharp/Microsoft.Spark/ML/Param/Param.cs
index 0ffc8cfa..3a9bf252 100644
--- a/src/csharp/Microsoft.Spark/ML/Param/Param.cs
+++ b/src/csharp/Microsoft.Spark/ML/Param/Param.cs
@@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
@@ -52,10 +51,7 @@ namespace Microsoft.Spark.ML.Feature.Param
{
}
- internal Param(JvmObjectReference jvmObject)
- {
- Reference = jvmObject;
- }
+ internal Param(JvmObjectReference jvmObject) => Reference = jvmObject;
public JvmObjectReference Reference { get; private set; }
diff --git a/src/csharp/Microsoft.Spark/ML/Param/ParamMap.cs b/src/csharp/Microsoft.Spark/ML/Param/ParamMap.cs
new file mode 100644
index 00000000..543ea228
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Param/ParamMap.cs
@@ -0,0 +1,46 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.ML.Feature.Param
+{
+ /// <summary>
+ /// A param to value map.
+ /// </summary>
+ public class ParamMap : IJvmObjectReferenceProvider
+ {
+ private static readonly string s_ParamMapClassName = "org.apache.spark.ml.param.ParamMap";
+
+ /// <summary>
+ /// Creates a new instance of a <see cref="ParamMap"/>
+ /// </summary>
+ public ParamMap() : this(SparkEnvironment.JvmBridge.CallConstructor(s_ParamMapClassName))
+ {
+ }
+
+ internal ParamMap(JvmObjectReference jvmObject) => Reference = jvmObject;
+
+ public JvmObjectReference Reference { get; private set; }
+
+ /// <summary>
+ /// Puts a (param, value) pair (overwrites if the input param exists).
+ /// </summary>
+ /// <param name="param">The param to be added</param>
+ /// <param name="value">The param value to be added</param>
+ public ParamMap Put<T>(Param param, T value) =>
+ WrapAsParamMap((JvmObjectReference)Reference.Invoke("put", param, value));
+
+ /// <summary>
+ /// Returns the string representation of this ParamMap.
+ /// </summary>
+ /// <returns>The representation as a string value.</returns>
+ public override string ToString() =>
+ (string)Reference.Invoke("toString");
+
+ private static ParamMap WrapAsParamMap(object obj) =>
+ new ParamMap((JvmObjectReference)obj);
+ }
+}
diff --git a/src/csharp/Microsoft.Spark/ML/Param/ParamPair.cs b/src/csharp/Microsoft.Spark/ML/Param/ParamPair.cs
new file mode 100644
index 00000000..359f8a7e
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Param/ParamPair.cs
@@ -0,0 +1,29 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+
+namespace Microsoft.Spark.ML.Feature.Param
+{
+ /// <summary>
+ /// A param and its value.
+ /// </summary>
+ public sealed class ParamPair<T> : IJvmObjectReferenceProvider
+ {
+ private static readonly string s_ParamPairClassName = "org.apache.spark.ml.param.ParamPair";
+
+ /// <summary>
+ /// Creates a new instance of a <see cref="ParamPair{T}"/>
+ /// </summary>
+ public ParamPair(Param param, T value)
+ : this(SparkEnvironment.JvmBridge.CallConstructor(s_ParamPairClassName, param, value))
+ {
+ }
+
+ internal ParamPair(JvmObjectReference jvmObject) => Reference = jvmObject;
+
+ public JvmObjectReference Reference { get; private set; }
+ }
+}
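A sketch of how these fit together, assuming the GetParam helper exposed by the shared base class (renamed to Base.cs in this change); the param name and value are illustrative:

    using System;
    using Microsoft.Spark.ML.Feature;
    using Microsoft.Spark.ML.Feature.Param;

    static void BuildParamMap()
    {
        var bucketizer = new Bucketizer();
        Param inputCol = bucketizer.GetParam("inputCol");

        // Put overwrites any earlier value bound to the same param.
        ParamMap paramMap = new ParamMap().Put(inputCol, "input_col");
        var pair = new ParamPair<string>(inputCol, "input_col");

        Console.WriteLine(paramMap.ToString());
    }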
diff --git a/src/csharp/Microsoft.Spark/ML/Util/Read.cs b/src/csharp/Microsoft.Spark/ML/Util/Read.cs
new file mode 100644
index 00000000..57260921
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Util/Read.cs
@@ -0,0 +1,67 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Linq;
+using System.Reflection;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+
+namespace Microsoft.Spark.ML.Feature
+{
+ /// <summary>
+ /// Class for utility classes that can load ML instances.
+ /// </summary>
+ /// <typeparam name="T">ML instance type</typeparam>
+ public class JavaMLReader<T> : IJvmObjectReferenceProvider
+ {
+ internal JavaMLReader(JvmObjectReference jvmObject) => Reference = jvmObject;
+
+ public JvmObjectReference Reference { get; private set; }
+
+ /// <summary>
+ /// Loads the ML component from the input path.
+ /// </summary>
+ /// <param name="path">The path the previous instance of type T was saved to</param>
+ /// <returns>The type T instance</returns>
+ public T Load(string path) =>
+ WrapAsType((JvmObjectReference)Reference.Invoke("load", path));
+
+ /// <summary>Sets the Spark Session to use for saving/loading.</summary>
+ /// <param name="sparkSession">The Spark Session to be set</param>
+ public JavaMLReader<T> Session(SparkSession sparkSession)
+ {
+ Reference.Invoke("session", sparkSession);
+ return this;
+ }
+
+ private static T WrapAsType(JvmObjectReference reference)
+ {
+ ConstructorInfo constructor = typeof(T)
+ .GetConstructors(BindingFlags.NonPublic | BindingFlags.Instance)
+ .Single(c =>
+ {
+ ParameterInfo[] parameters = c.GetParameters();
+ return (parameters.Length == 1) &&
+ (parameters[0].ParameterType == typeof(JvmObjectReference));
+ });
+
+ return (T)constructor.Invoke(new object[] { reference });
+ }
+ }
+
+ /// <summary>
+ /// Interface for objects that provide MLReader.
+ /// </summary>
+ /// <typeparam name="T">
+ /// ML instance type
+ /// </typeparam>
+ public interface IJavaMLReadable<T>
+ {
+ /// <summary>
+ /// Get the corresponding JavaMLReader instance.
+ /// </summary>
+ /// <returns>an <see cref="JavaMLReader{T}"/> instance for this ML instance.</returns>
+ JavaMLReader<T> Read();
+ }
+}
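A minimal loading sketch (the path is illustrative; setting the session explicitly is optional, as the reader is assumed to default to the active session):

    using Microsoft.Spark.ML.Feature;
    using Microsoft.Spark.Sql;

    static Pipeline LoadPipeline(SparkSession spark, string path)
    {
        // Read() is supplied by the IJavaMLReadable<Pipeline> implementation.
        JavaMLReader<Pipeline> reader = new Pipeline().Read().Session(spark);
        return reader.Load(path);
    }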
diff --git a/src/csharp/Microsoft.Spark/ML/Util/Write.cs b/src/csharp/Microsoft.Spark/ML/Util/Write.cs
new file mode 100644
index 00000000..8767e2ad
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Util/Write.cs
@@ -0,0 +1,73 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+
+namespace Microsoft.Spark.ML.Feature
+{
+ /// <summary>
+ /// Class for utility classes that can save ML instances in Spark's internal format.
+ /// </summary>
+ public class JavaMLWriter : IJvmObjectReferenceProvider
+ {
+ internal JavaMLWriter(JvmObjectReference jvmObject) => Reference = jvmObject;
+
+ public JvmObjectReference Reference { get; private set; }
+
+ /// <summary>Saves the ML instance to the input path.</summary>
+ /// <param name="path">The path to save the object to</param>
+ public void Save(string path) => Reference.Invoke("save", path);
+
+ /// <summary>
+ /// save() handles overwriting and then calls this method.
+ /// Subclasses should override this method to implement the actual saving of the instance.
+ /// </summary>
+ /// <param name="path">The path to save the object to</param>
+ protected void SaveImpl(string path) => Reference.Invoke("saveImpl", path);
+
+ /// <summary>Overwrites if the output path already exists.</summary>
+ public JavaMLWriter Overwrite()
+ {
+ Reference.Invoke("overwrite");
+ return this;
+ }
+
+ /// <summary>
+ /// Adds an option to the underlying MLWriter. See the documentation for the specific model's
+ /// writer for possible options. The option name (key) is case-insensitive.
+ /// </summary>
+ /// <param name="key">key of the option</param>
+ /// <param name="value">value of the option</param>
+ public JavaMLWriter Option(string key, string value)
+ {
+ Reference.Invoke("option", key, value);
+ return this;
+ }
+
+ /// <summary>Sets the Spark Session to use for saving/loading.</summary>
+ /// <param name="sparkSession">The Spark Session to be set</param>
+ public JavaMLWriter Session(SparkSession sparkSession)
+ {
+ Reference.Invoke("session", sparkSession);
+ return this;
+ }
+ }
+
+ /// <summary>
+ /// Interface for classes that provide JavaMLWriter.
+ /// </summary>
+ public interface IJavaMLWritable
+ {
+ /// <summary>
+ /// Get the corresponding JavaMLWriter instance.
+ /// </summary>
+ /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
+ JavaMLWriter Write();
+
+ /// <summary>Saves this ML instance to the input path</summary>
+ /// <param name="path">The path to save the object to</param>
+ void Save(string path);
+ }
+}
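Because Overwrite, Option, and Session each return the writer itself, calls chain fluently. A sketch (the option key here is illustrative, not a documented Pipeline option):

    using Microsoft.Spark.ML.Feature;

    static void SaveOverwriting(Pipeline pipeline, string path)
    {
        pipeline.Write()
            .Overwrite()
            .Option("persistSubModels", "true") // key shown for illustration only
            .Save(path);
    }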
diff --git a/src/scala/microsoft-spark-2-4/pom.xml b/src/scala/microsoft-spark-2-4/pom.xml
index 1c959bb0..3a8531d2 100644
--- a/src/scala/microsoft-spark-2-4/pom.xml
+++ b/src/scala/microsoft-spark-2-4/pom.xml
@@ -33,6 +33,12 @@
${spark.version}
provided
+
+ org.apache.spark
+ spark-mllib_${scala.binary.version}
+ ${spark.version}
+ provided
+
junit
junit
diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
new file mode 100644
index 00000000..9f556338
--- /dev/null
+++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import scala.collection.JavaConverters._
+
+/** DotnetUtils object that hosts helper functions
+ * for data type conversions between .NET and Scala.
+ */
+object DotnetUtils {
+
+ /** A helper function to convert scala Map to java.util.Map
+ * @param value - scala Map
+ * @return java.util.Map
+ */
+ def convertToJavaMap(value: Map[_, _]): java.util.Map[_, _] = value.asJava
+
+ /** Convert a Java boxed value to the corresponding Scala type
+ * @param value - java.lang.Object
+ * @return Any
+ */
+ def mapScalaToJava(value: java.lang.Object): Any = {
+ value match {
+ case i: java.lang.Integer => i.toInt
+ case d: java.lang.Double => d.toDouble
+ case f: java.lang.Float => f.toFloat
+ case b: java.lang.Boolean => b.booleanValue()
+ case l: java.lang.Long => l.toLong
+ case s: java.lang.Short => s.toShort
+ case by: java.lang.Byte => by.toByte
+ case c: java.lang.Character => c.toChar
+ case _ => value
+ }
+ }
+}
diff --git a/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
new file mode 100644
index 00000000..3e3c3e0e
--- /dev/null
+++ b/src/scala/microsoft-spark-2-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
@@ -0,0 +1,26 @@
+
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.mllib.api.dotnet
+
+import org.apache.spark.ml._
+import scala.collection.JavaConverters._
+
+/** MLUtils object that hosts helper functions
+ * related to ML usage
+ */
+object MLUtils {
+
+ /** A helper function that lets a Pipeline accept its stages as a
+ * java.util.ArrayList from the .NET side
+ * @param pipeline - The pipeline whose stages are to be set
+ * @param value - A java.util.ArrayList of PipelineStages to be set as stages
+ * @return The pipeline
+ */
+ def setPipelineStages(pipeline: Pipeline, value: java.util.ArrayList[_ <: PipelineStage]): Pipeline =
+ pipeline.setStages(value.asScala.toArray)
+}
diff --git a/src/scala/microsoft-spark-3-0/pom.xml b/src/scala/microsoft-spark-3-0/pom.xml
index 90294ab3..5028764a 100644
--- a/src/scala/microsoft-spark-3-0/pom.xml
+++ b/src/scala/microsoft-spark-3-0/pom.xml
@@ -33,6 +33,12 @@
${spark.version}
provided
+
+ org.apache.spark
+ spark-mllib_${scala.binary.version}
+ ${spark.version}
+ provided
+
junit
junit
diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
new file mode 100644
index 00000000..9f556338
--- /dev/null
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import scala.collection.JavaConverters._
+
+/** DotnetUtils object that hosts helper functions
+ * for data type conversions between .NET and Scala.
+ */
+object DotnetUtils {
+
+ /** A helper function to convert scala Map to java.util.Map
+ * @param value - scala Map
+ * @return java.util.Map
+ */
+ def convertToJavaMap(value: Map[_, _]): java.util.Map[_, _] = value.asJava
+
+ /** Convert a Java boxed value to the corresponding Scala type
+ * @param value - java.lang.Object
+ * @return Any
+ */
+ def mapScalaToJava(value: java.lang.Object): Any = {
+ value match {
+ case i: java.lang.Integer => i.toInt
+ case d: java.lang.Double => d.toDouble
+ case f: java.lang.Float => f.toFloat
+ case b: java.lang.Boolean => b.booleanValue()
+ case l: java.lang.Long => l.toLong
+ case s: java.lang.Short => s.toShort
+ case by: java.lang.Byte => by.toByte
+ case c: java.lang.Character => c.toChar
+ case _ => value
+ }
+ }
+}
diff --git a/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
new file mode 100644
index 00000000..3e3c3e0e
--- /dev/null
+++ b/src/scala/microsoft-spark-3-0/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
@@ -0,0 +1,26 @@
+
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.mllib.api.dotnet
+
+import org.apache.spark.ml._
+import scala.collection.JavaConverters._
+
+/** MLUtils object that hosts helper functions
+ * related to ML usage
+ */
+object MLUtils {
+
+ /** A helper function that lets a Pipeline accept its stages as a
+ * java.util.ArrayList from the .NET side
+ * @param pipeline - The pipeline whose stages are to be set
+ * @param value - A java.util.ArrayList of PipelineStages to be set as stages
+ * @return The pipeline
+ */
+ def setPipelineStages(pipeline: Pipeline, value: java.util.ArrayList[_ <: PipelineStage]): Pipeline =
+ pipeline.setStages(value.asScala.toArray)
+}
diff --git a/src/scala/microsoft-spark-3-1/pom.xml b/src/scala/microsoft-spark-3-1/pom.xml
index 8eb58b44..23be8a15 100644
--- a/src/scala/microsoft-spark-3-1/pom.xml
+++ b/src/scala/microsoft-spark-3-1/pom.xml
@@ -33,6 +33,12 @@
${spark.version}
provided
+
+ org.apache.spark
+ spark-mllib_${scala.binary.version}
+ ${spark.version}
+ provided
+
junit
junit
diff --git a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
new file mode 100644
index 00000000..9f556338
--- /dev/null
+++ b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import scala.collection.JavaConverters._
+
+/** DotnetUtils object that hosts helper functions
+ * for data type conversions between .NET and Scala.
+ */
+object DotnetUtils {
+
+ /** A helper function to convert scala Map to java.util.Map
+ * @param value - scala Map
+ * @return java.util.Map
+ */
+ def convertToJavaMap(value: Map[_, _]): java.util.Map[_, _] = value.asJava
+
+ /** Convert a Java boxed value to the corresponding Scala type
+ * @param value - java.lang.Object
+ * @return Any
+ */
+ def mapScalaToJava(value: java.lang.Object): Any = {
+ value match {
+ case i: java.lang.Integer => i.toInt
+ case d: java.lang.Double => d.toDouble
+ case f: java.lang.Float => f.toFloat
+ case b: java.lang.Boolean => b.booleanValue()
+ case l: java.lang.Long => l.toLong
+ case s: java.lang.Short => s.toShort
+ case by: java.lang.Byte => by.toByte
+ case c: java.lang.Character => c.toChar
+ case _ => value
+ }
+ }
+}
diff --git a/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
new file mode 100644
index 00000000..3e3c3e0e
--- /dev/null
+++ b/src/scala/microsoft-spark-3-1/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
@@ -0,0 +1,26 @@
+
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.mllib.api.dotnet
+
+import org.apache.spark.ml._
+import scala.collection.JavaConverters._
+
+/** MLUtils object that hosts helper functions
+ * related to ML usage
+ */
+object MLUtils {
+
+ /** A helper function that lets a Pipeline accept its stages as a
+ * java.util.ArrayList from the .NET side
+ * @param pipeline - The pipeline whose stages are to be set
+ * @param value - A java.util.ArrayList of PipelineStages to be set as stages
+ * @return The pipeline
+ */
+ def setPipelineStages(pipeline: Pipeline, value: java.util.ArrayList[_ <: PipelineStage]): Pipeline =
+ pipeline.setStages(value.asScala.toArray)
+}
diff --git a/src/scala/microsoft-spark-3-2/pom.xml b/src/scala/microsoft-spark-3-2/pom.xml
index b6efeb17..d3dd318f 100644
--- a/src/scala/microsoft-spark-3-2/pom.xml
+++ b/src/scala/microsoft-spark-3-2/pom.xml
@@ -33,6 +33,12 @@
${spark.version}
provided
+
+ org.apache.spark
+ spark-mllib_${scala.binary.version}
+ ${spark.version}
+ provided
+
junit
junit
diff --git a/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala b/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
new file mode 100644
index 00000000..9f556338
--- /dev/null
+++ b/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.api.dotnet
+
+import scala.collection.JavaConverters._
+
+/** DotnetUtils object that hosts helper functions
+ * for data type conversions between .NET and Scala.
+ */
+object DotnetUtils {
+
+ /** A helper function to convert scala Map to java.util.Map
+ * @param value - scala Map
+ * @return java.util.Map
+ */
+ def convertToJavaMap(value: Map[_, _]): java.util.Map[_, _] = value.asJava
+
+ /** Convert a Java boxed value to the corresponding Scala type
+ * @param value - java.lang.Object
+ * @return Any
+ */
+ def mapScalaToJava(value: java.lang.Object): Any = {
+ value match {
+ case i: java.lang.Integer => i.toInt
+ case d: java.lang.Double => d.toDouble
+ case f: java.lang.Float => f.toFloat
+ case b: java.lang.Boolean => b.booleanValue()
+ case l: java.lang.Long => l.toLong
+ case s: java.lang.Short => s.toShort
+ case by: java.lang.Byte => by.toByte
+ case c: java.lang.Character => c.toChar
+ case _ => value
+ }
+ }
+}
diff --git a/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala b/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
new file mode 100644
index 00000000..3e3c3e0e
--- /dev/null
+++ b/src/scala/microsoft-spark-3-2/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala
@@ -0,0 +1,26 @@
+
+/*
+ * Licensed to the .NET Foundation under one or more agreements.
+ * The .NET Foundation licenses this file to you under the MIT license.
+ * See the LICENSE file in the project root for more information.
+ */
+
+package org.apache.spark.mllib.api.dotnet
+
+import org.apache.spark.ml._
+import scala.collection.JavaConverters._
+
+/** MLUtils object that hosts helper functions
+ * related to ML usage
+ */
+object MLUtils {
+
+ /** A helper function that lets a Pipeline accept its stages as a
+ * java.util.ArrayList from the .NET side
+ * @param pipeline - The pipeline whose stages are to be set
+ * @param value - A java.util.ArrayList of PipelineStages to be set as stages
+ * @return The pipeline
+ */
+ def setPipelineStages(pipeline: Pipeline, value: java.util.ArrayList[_ <: PipelineStage]): Pipeline =
+ pipeline.setStages(value.asScala.toArray)
+}