Skip to content

Commit

Permalink
Add assembly check when the pipeline tries to get stages (dotnet#1049)
Browse files Browse the repository at this point in the history
  • Loading branch information
serena-ruan authored May 6, 2022
1 parent c2ee198 commit c96efde
Show file tree
Hide file tree
Showing 17 changed files with 173 additions and 94 deletions.
25 changes: 0 additions & 25 deletions src/csharp/Microsoft.Spark/ML/Feature/Base.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,29 +119,4 @@ protected static T WrapAsType<T>(JvmObjectReference reference)
return (T)constructor.Invoke(new object[] { reference });
}
}

/// <summary>
/// DotnetUtils is used to hold basic general helper functions that
/// are used within ML scope.
/// </summary>
internal class DotnetUtils
{
/// <summary>
/// Helper function for getting the exact class name from jvm object.
/// </summary>
/// <param name="jvmObject">The reference to object created in JVM.</param>
/// <returns>A string Tuple2 of constructor class name and method name</returns>
internal static (string, string) GetUnderlyingType(JvmObjectReference jvmObject)
{
var jvmClass = (JvmObjectReference)jvmObject.Invoke("getClass");
var returnClass = (string)jvmClass.Invoke("getTypeName");
string[] dotnetClass = returnClass.Replace("com.microsoft.azure.synapse.ml", "Synapse.ML")
.Replace("org.apache.spark.ml", "Microsoft.Spark.ML")
.Split(".".ToCharArray());
string[] renameClass = dotnetClass.Select(x => char.ToUpper(x[0]) + x.Substring(1)).ToArray();
string constructorClass = string.Join(".", renameClass);
string methodName = "WrapAs" + dotnetClass[dotnetClass.Length - 1];
return (constructorClass, methodName);
}
}
}
8 changes: 4 additions & 4 deletions src/csharp/Microsoft.Spark/ML/Feature/Bucketizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ public class Bucketizer :
IJavaMLWritable,
IJavaMLReadable<Bucketizer>
{
private static readonly string s_bucketizerClassName =
private static readonly string s_className =
"org.apache.spark.ml.feature.Bucketizer";

/// <summary>
/// Create a <see cref="Bucketizer"/> without any parameters
/// </summary>
public Bucketizer() : base(s_bucketizerClassName)
public Bucketizer() : base(s_className)
{
}

Expand All @@ -39,7 +39,7 @@ public Bucketizer() : base(s_bucketizerClassName)
/// <see cref="Bucketizer"/> a unique ID
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public Bucketizer(string uid) : base(s_bucketizerClassName, uid)
public Bucketizer(string uid) : base(s_className, uid)
{
}

Expand Down Expand Up @@ -163,7 +163,7 @@ public Bucketizer SetOutputCols(List<string> value) =>
public static Bucketizer Load(string path) =>
WrapAsBucketizer(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_bucketizerClassName, "load", path));
s_className, "load", path));

/// <summary>
/// Executes the <see cref="Bucketizer"/> and transforms the DataFrame to include the new
Expand Down
8 changes: 4 additions & 4 deletions src/csharp/Microsoft.Spark/ML/Feature/CountVectorizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ public class CountVectorizer :
IJavaMLWritable,
IJavaMLReadable<CountVectorizer>
{
private static readonly string s_countVectorizerClassName =
private static readonly string s_className =
"org.apache.spark.ml.feature.CountVectorizer";

/// <summary>
/// Creates a <see cref="CountVectorizer"/> without any parameters.
/// </summary>
public CountVectorizer() : base(s_countVectorizerClassName)
public CountVectorizer() : base(s_className)
{
}

Expand All @@ -28,7 +28,7 @@ public CountVectorizer() : base(s_countVectorizerClassName)
/// <see cref="CountVectorizer"/> a unique ID.
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public CountVectorizer(string uid) : base(s_countVectorizerClassName, uid)
public CountVectorizer(string uid) : base(s_className, uid)
{
}

Expand All @@ -52,7 +52,7 @@ public override CountVectorizerModel Fit(DataFrame dataFrame) =>
public static CountVectorizer Load(string path) =>
WrapAsCountVectorizer((JvmObjectReference)
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_countVectorizerClassName, "load", path));
s_className, "load", path));

/// <summary>
/// Gets the binary toggle to control the output vector values. If True, all nonzero counts
Expand Down
8 changes: 4 additions & 4 deletions src/csharp/Microsoft.Spark/ML/Feature/CountVectorizerModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class CountVectorizerModel :
IJavaMLWritable,
IJavaMLReadable<CountVectorizerModel>
{
private static readonly string s_countVectorizerModelClassName =
private static readonly string s_className =
"org.apache.spark.ml.feature.CountVectorizerModel";

/// <summary>
Expand All @@ -24,7 +24,7 @@ public class CountVectorizerModel :
/// <param name="vocabulary">The vocabulary to use</param>
public CountVectorizerModel(List<string> vocabulary)
: this(SparkEnvironment.JvmBridge.CallConstructor(
s_countVectorizerModelClassName, vocabulary))
s_className, vocabulary))
{
}

Expand All @@ -36,7 +36,7 @@ public CountVectorizerModel(List<string> vocabulary)
/// <param name="vocabulary">The vocabulary to use</param>
public CountVectorizerModel(string uid, List<string> vocabulary)
: this(SparkEnvironment.JvmBridge.CallConstructor(
s_countVectorizerModelClassName, uid, vocabulary))
s_className, uid, vocabulary))
{
}

Expand All @@ -54,7 +54,7 @@ internal CountVectorizerModel(JvmObjectReference jvmObject) : base(jvmObject)
public static CountVectorizerModel Load(string path) =>
WrapAsCountVectorizerModel(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_countVectorizerModelClassName, "load", path));
s_className, "load", path));

/// <summary>
/// Gets the binary toggle to control the output vector values. If True, all nonzero counts
Expand Down
8 changes: 4 additions & 4 deletions src/csharp/Microsoft.Spark/ML/Feature/FeatureHasher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ public class FeatureHasher :
IJavaMLWritable,
IJavaMLReadable<FeatureHasher>
{
private static readonly string s_featureHasherClassName =
private static readonly string s_className =
"org.apache.spark.ml.feature.FeatureHasher";

/// <summary>
/// Creates a <see cref="FeatureHasher"/> without any parameters.
/// </summary>
public FeatureHasher() : base(s_featureHasherClassName)
public FeatureHasher() : base(s_className)
{
}

Expand All @@ -30,7 +30,7 @@ public FeatureHasher() : base(s_featureHasherClassName)
/// <see cref="FeatureHasher"/> a unique ID.
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public FeatureHasher(string uid) : base(s_featureHasherClassName, uid)
public FeatureHasher(string uid) : base(s_className, uid)
{
}

Expand All @@ -48,7 +48,7 @@ internal FeatureHasher(JvmObjectReference jvmObject) : base(jvmObject)
public static FeatureHasher Load(string path) =>
WrapAsFeatureHasher(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_featureHasherClassName,
s_className,
"load",
path));

Expand Down
8 changes: 4 additions & 4 deletions src/csharp/Microsoft.Spark/ML/Feature/HashingTF.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ public class HashingTF :
IJavaMLWritable,
IJavaMLReadable<HashingTF>
{
private static readonly string s_hashingTfClassName =
private static readonly string s_className =
"org.apache.spark.ml.feature.HashingTF";

/// <summary>
/// Create a <see cref="HashingTF"/> without any parameters
/// </summary>
public HashingTF() : base(s_hashingTfClassName)
public HashingTF() : base(s_className)
{
}

Expand All @@ -36,7 +36,7 @@ public HashingTF() : base(s_hashingTfClassName)
/// <see cref="HashingTF"/> a unique ID
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public HashingTF(string uid) : base(s_hashingTfClassName, uid)
public HashingTF(string uid) : base(s_className, uid)
{
}

Expand All @@ -52,7 +52,7 @@ internal HashingTF(JvmObjectReference jvmObject) : base(jvmObject)
public static HashingTF Load(string path) =>
WrapAsHashingTF(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_hashingTfClassName, "load", path));
s_className, "load", path));

/// <summary>
/// Gets the binary toggle that controls term frequency counts
Expand Down
8 changes: 4 additions & 4 deletions src/csharp/Microsoft.Spark/ML/Feature/IDF.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ public class IDF :
IJavaMLWritable,
IJavaMLReadable<IDF>
{
private static readonly string s_IDFClassName = "org.apache.spark.ml.feature.IDF";
private static readonly string s_className = "org.apache.spark.ml.feature.IDF";

/// <summary>
/// Create a <see cref="IDF"/> without any parameters
/// </summary>
public IDF() : base(s_IDFClassName)
public IDF() : base(s_className)
{
}

Expand All @@ -36,7 +36,7 @@ public IDF() : base(s_IDFClassName)
/// <see cref="IDF"/> a unique ID
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public IDF(string uid) : base(s_IDFClassName, uid)
public IDF(string uid) : base(s_className, uid)
{
}

Expand Down Expand Up @@ -103,7 +103,7 @@ public override IDFModel Fit(DataFrame source) =>
public static IDF Load(string path)
{
return WrapAsIDF(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_IDFClassName, "load", path));
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_className, "load", path));
}

/// <summary>
Expand Down
8 changes: 4 additions & 4 deletions src/csharp/Microsoft.Spark/ML/Feature/IDFModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ public class IDFModel :
IJavaMLWritable,
IJavaMLReadable<IDFModel>
{
private static readonly string s_IDFModelClassName =
private static readonly string s_className =
"org.apache.spark.ml.feature.IDFModel";

/// <summary>
/// Create a <see cref="IDFModel"/> without any parameters
/// </summary>
public IDFModel() : base(s_IDFModelClassName)
public IDFModel() : base(s_className)
{
}

Expand All @@ -32,7 +32,7 @@ public IDFModel() : base(s_IDFModelClassName)
/// <see cref="IDFModel"/> a unique ID
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public IDFModel(string uid) : base(s_IDFModelClassName, uid)
public IDFModel(string uid) : base(s_className, uid)
{
}

Expand Down Expand Up @@ -96,7 +96,7 @@ public static IDFModel Load(string path)
{
return WrapAsIDFModel(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_IDFModelClassName, "load", path));
s_className, "load", path));
}

/// <summary>
Expand Down
8 changes: 4 additions & 4 deletions src/csharp/Microsoft.Spark/ML/Feature/NGram.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ public class NGram :
IJavaMLWritable,
IJavaMLReadable<NGram>
{
private static readonly string s_nGramClassName =
private static readonly string s_className =
"org.apache.spark.ml.feature.NGram";

/// <summary>
/// Create a <see cref="NGram"/> without any parameters.
/// </summary>
public NGram() : base(s_nGramClassName)
public NGram() : base(s_className)
{
}

Expand All @@ -35,7 +35,7 @@ public NGram() : base(s_nGramClassName)
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.
/// </param>
public NGram(string uid) : base(s_nGramClassName, uid)
public NGram(string uid) : base(s_className, uid)
{
}

Expand Down Expand Up @@ -123,7 +123,7 @@ public override StructType TransformSchema(StructType value) =>
public static NGram Load(string path) =>
WrapAsNGram(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_nGramClassName,
s_className,
"load",
path));

Expand Down
37 changes: 24 additions & 13 deletions src/csharp/Microsoft.Spark/ML/Feature/Pipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Reflection;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Utils;
using System.Collections.Generic;

namespace Microsoft.Spark.ML.Feature
{
Expand All @@ -26,12 +27,12 @@ public class Pipeline :
IJavaMLWritable,
IJavaMLReadable<Pipeline>
{
private static readonly string s_pipelineClassName = "org.apache.spark.ml.Pipeline";
private static readonly string s_className = "org.apache.spark.ml.Pipeline";

/// <summary>
/// Creates a <see cref="Pipeline"/> without any parameters.
/// </summary>
public Pipeline() : base(s_pipelineClassName)
public Pipeline() : base(s_className)
{
}

Expand All @@ -40,7 +41,7 @@ public Pipeline() : base(s_pipelineClassName)
/// <see cref="Pipeline"/> a unique ID.
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public Pipeline(string uid) : base(s_pipelineClassName, uid)
public Pipeline(string uid) : base(s_className, uid)
{
}

Expand All @@ -57,24 +58,34 @@ internal Pipeline(JvmObjectReference jvmObject) : base(jvmObject)
/// <returns><see cref="Pipeline"/> object</returns>
public Pipeline SetStages(JavaPipelineStage[] value) =>
WrapAsPipeline((JvmObjectReference)SparkEnvironment.JvmBridge.CallStaticJavaMethod(
"org.apache.spark.mllib.api.dotnet.MLUtils", "setPipelineStages",
Reference, value.ToJavaArrayList()));
"org.apache.spark.mllib.api.dotnet.MLUtils",
"setPipelineStages",
Reference,
value.ToJavaArrayList()));

/// <summary>
/// Get the stages of pipeline instance.
/// </summary>
/// <returns>A sequence of <see cref="JavaPipelineStage"/> stages</returns>
public JavaPipelineStage[] GetStages()
{
JvmObjectReference[] jvmObjects = (JvmObjectReference[])Reference.Invoke("getStages");
JavaPipelineStage[] result = new JavaPipelineStage[jvmObjects.Length];
var jvmObjects = (JvmObjectReference[])Reference.Invoke("getStages");
var result = new JavaPipelineStage[jvmObjects.Length];
Dictionary<string, Type> classMapping = JvmObjectUtils.ConstructJavaClassMapping(
typeof(JavaPipelineStage),
"s_className");

for (int i = 0; i < jvmObjects.Length; i++)
{
(string constructorClass, string methodName) = DotnetUtils.GetUnderlyingType(jvmObjects[i]);
Type type = Type.GetType(constructorClass);
MethodInfo method = type.GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Static);
result[i] = (JavaPipelineStage)method.Invoke(null, new object[] { jvmObjects[i] });
if (JvmObjectUtils.TryConstructInstanceFromJvmObject(
jvmObjects[i],
classMapping,
out JavaPipelineStage instance))
{
result[i] = instance;
}
}

return result;
}

Expand All @@ -91,7 +102,7 @@ override public PipelineModel Fit(DataFrame dataset) =>
/// <param name="path">The path the previous <see cref="Pipeline"/> was saved to</param>
/// <returns>New <see cref="Pipeline"/> object, loaded from path.</returns>
public static Pipeline Load(string path) => WrapAsPipeline(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_pipelineClassName, "load", path));
SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_className, "load", path));

/// <summary>
/// Saves the object so that it can be loaded later using Load. Note that these objects
Expand Down
Loading

0 comments on commit c96efde

Please sign in to comment.