fix: companionModelClassName no longer returns generic type variable (#2195)

* fix: companionModelClassName no longer returns generic type variable

* Move TestRegressor to its own file

* added fuzzing exclusions

* fixing style

* removed comma

---------
This commit is contained in:
Brendan Walsh 2024-04-09 13:15:50 -07:00 коммит произвёл GitHub
Родитель 580f7594e0
Коммит 50c7f1ebbd
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
5 изменённых файлов: 91 добавлений и 31 удалений

Просмотреть файл

@ -11,10 +11,9 @@ import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.ml.{Estimator, Model, Transformer}
import java.lang.reflect.ParameterizedType
import scala.reflect.runtime.universe._
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import scala.collection.Iterator.iterate
trait BaseWrappable extends Params {
@ -28,19 +27,24 @@ trait BaseWrappable extends Params {
/** Last dot-separated segment of `thisStage.getClass.getName`, computed once on first access. */
protected lazy val classNameHelper: String = thisStage.getClass.getName.split(".".toCharArray).last
/** Fully-qualified class name of the model type this stage is parameterised with.
  *
  * The lookup goes through Scala runtime reflection so that the type argument is
  * resolved as seen from the concrete class of `thisStage`, yielding a class name
  * rather than the name of a generic type variable.
  *
  * @return the full name of the class symbol bound to the model type parameter
  * @throws NoSuchElementException if none of the recognised base classes
  *                                (Estimator, ProbabilisticClassifier, Predictor,
  *                                BaseRegressor, Ranker) appears among the base
  *                                classes of `thisStage`
  */
protected def companionModelClassName: String = {
  val recognisedBases = Set("Estimator", "ProbabilisticClassifier", "Predictor", "BaseRegressor", "Ranker")
  val symbol = scala.reflect.runtime.currentMirror.classSymbol(thisStage.getClass)
  // baseClasses is searched in the order the reflection API returns it; the first
  // recognised base decides which type-argument position holds the model type.
  val superClassSymbol = symbol.baseClasses
    .find(s => recognisedBases.contains(s.name.toString))
    .getOrElse(throw new NoSuchElementException("Matching superclass was not found: " + symbol.baseClasses))
  // Type arguments of that base class as instantiated by the concrete stage class.
  val typeArgs = symbol.toType.baseType(superClassSymbol).typeArgs
  val modelTypeArg = superClassSymbol.name.toString match {
    // For Estimator the model type is taken from the first type argument ...
    case "Estimator" => typeArgs.head
    // ... for every other recognised base it is taken from the last one.
    case _ => typeArgs.last
  }
  modelTypeArg.typeSymbol.asClass.fullName
}
def getParamInfo(p: Param[_]): ParamInfo[_] = {

Просмотреть файл

@ -32,7 +32,13 @@ object JarLoadingUtils {
/** Instantiates every concrete class in `AllClasses` that is assignable to `T`.
  *
  * @param instantiate factory invoked once per selected class
  * @param jarName     when defined, only classes whose `.class` resource URL contains
  *                    this string are kept; when empty, no resource lookup is done
  * @tparam T          type the selected classes must be assignable to
  * @return the objects produced by `instantiate`, cast to `List[T]`
  * @throws IOException if `jarName` is defined and the `.class` resource of a
  *                     candidate class cannot be located
  */
def instantiateServices[T: ClassTag](instantiate: Class[_] => Any, jarName: Option[String]): List[T] = {
  val wanted = classTag[T].runtimeClass

  // Resolve the class file's URL and check it against the requested jar name.
  // getResource returns null when the resource is missing, so wrap it in Option
  // and fail with a descriptive error instead of a NullPointerException.
  def residesIn(c: Class[_], jar: String): Boolean =
    Option(c.getResource(c.getSimpleName + ".class"))
      .getOrElse(throw new IOException(s"Could not find resource for class ${c.getSimpleName}"))
      .toString
      .contains(jar)

  AllClasses
    .filter(wanted.isAssignableFrom(_))
    // forall on an empty Option is vacuously true, so the lookup only runs
    // when a jar filter was actually supplied.
    .filter(c => jarName.forall(jar => residesIn(c, jar)))
    .filter(c => !Modifier.isAbstract(c.getModifiers))
    .map(instantiate(_))
    .asInstanceOf[List[T]]
}

Просмотреть файл

@ -0,0 +1,28 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.azure.synapse.ml.codegen
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.{RegressionModel, Regressor}
import org.apache.spark.sql.Dataset
/** Stub regression model visible only inside the `codegen` package.
  *
  * It carries no learned state: every prediction is `0.0`.
  */
private[codegen] class TestRegressorModel extends RegressionModel[Vector, TestRegressorModel] {

  /** Fixed identifier shared by every instance. */
  override val uid: String = "test"

  /** Returns `0.0` for any feature vector. */
  override def predict(features: Vector): Double = 0.0

  /** Copies this model via `defaultCopy`, applying the extra params. */
  override def copy(extra: ParamMap): TestRegressorModel = {
    defaultCopy(extra)
  }
}
/** Stub regressor, visible only inside the `codegen` package, that mixes in `Wrappable`
  * and exposes the protected `companionModelClassName` through a public accessor.
  */
private[codegen] class TestRegressor
  extends Regressor[Vector, TestRegressor, TestRegressorModel]
    with Wrappable {

  /** Fixed identifier shared by every instance. */
  override val uid: String = "test"

  /** Public view of the protected `companionModelClassName`. */
  def getCompanionModelClassName(): String = companionModelClassName

  /** Ignores the dataset and returns a fresh [[TestRegressorModel]]. */
  override protected def train(dataset: Dataset[_]): TestRegressorModel = {
    new TestRegressorModel()
  }

  /** Copies this estimator via `defaultCopy`, applying the extra params. */
  override def copy(extra: ParamMap): TestRegressor = {
    defaultCopy(extra)
  }
}

Просмотреть файл

@ -0,0 +1,20 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.azure.synapse.ml.codegen
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.{RegressionModel, Regressor}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset
/** Checks that `companionModelClassName` resolves to a class name for a stage
  * whose model type is supplied through a generic type parameter.
  */
class WrappableTests extends TestBase {

  test("test CompanionModelClassName") {
    val regressorCompanionModelClassName = new TestRegressor().getCompanionModelClassName()
    // TestRegressorModel is a top-level class in the codegen package, so its
    // full name has no enclosing-class segment.
    assert(regressorCompanionModelClassName ==
      "com.microsoft.azure.synapse.ml.codegen.TestRegressorModel")
  }
}

Просмотреть файл

@ -3,26 +3,17 @@
package com.microsoft.azure.synapse.ml.core.test.fuzzing
import com.microsoft.azure.synapse.ml.Secrets
import com.microsoft.azure.synapse.ml.build.BuildInfo
import com.microsoft.azure.synapse.ml.services.{HasAADToken, HasSubscriptionKey}
import com.microsoft.azure.synapse.ml.core.contracts.{HasFeaturesCol, HasInputCol, HasLabelCol, HasOutputCol}
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.services.{HasAADToken, HasSubscriptionKey}
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.{MLReadable, MLWritable}
import java.io.File
import java.lang.reflect.ParameterizedType
import java.nio.charset.MalformedInputException
import java.nio.file.Files
import scala.collection.JavaConverters._
import scala.io.Source
import scala.language.existentials
import scala.util.matching.Regex
/** Tests to validate fuzzing of modules. */
class FuzzingTest extends TestBase {
@ -79,8 +70,9 @@ class FuzzingTest extends TestBase {
"com.microsoft.azure.synapse.ml.lightgbm.LightGBMRankerModel",
"com.microsoft.azure.synapse.ml.services.form.FormOntologyTransformer",
"com.microsoft.azure.synapse.ml.services.anomaly.SimpleDetectMultivariateAnomaly",
"com.microsoft.azure.synapse.ml.automl.BestModel" //TODO add proper interfaces to all of these
"com.microsoft.azure.synapse.ml.automl.BestModel", //TODO add proper interfaces to all of these
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
)
val applicableStages = pipelineStages.filter(t => !exemptions(t.getClass.getName))
val applicableClasses = applicableStages.map(_.getClass.asInstanceOf[Class[_]]).toSet
@ -135,7 +127,9 @@ class FuzzingTest extends TestBase {
"com.microsoft.azure.synapse.ml.services.DetectMultivariateAnomaly",
"com.microsoft.azure.synapse.ml.services.form.FormOntologyTransformer",
"com.microsoft.azure.synapse.ml.services.anomaly.SimpleDetectMultivariateAnomaly",
"com.microsoft.azure.synapse.ml.vw.VowpalWabbitRegressionModel"
"com.microsoft.azure.synapse.ml.vw.VowpalWabbitRegressionModel",
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
)
val applicableStages = pipelineStages.filter(t => !exemptions(t.getClass.getName))
val applicableClasses = applicableStages.map(_.getClass.asInstanceOf[Class[_]]).toSet
@ -187,7 +181,9 @@ class FuzzingTest extends TestBase {
"com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel",
"com.microsoft.azure.synapse.ml.services.form.FormOntologyTransformer",
"com.microsoft.azure.synapse.ml.services.anomaly.SimpleDetectMultivariateAnomaly",
"com.microsoft.azure.synapse.ml.train.ComputePerInstanceStatistics"
"com.microsoft.azure.synapse.ml.train.ComputePerInstanceStatistics",
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
)
val applicableStages = pipelineStages.filter(t => !exemptions(t.getClass.getName))
val applicableClasses = applicableStages.map(_.getClass.asInstanceOf[Class[_]]).toSet
@ -241,7 +237,9 @@ class FuzzingTest extends TestBase {
"com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel",
"com.microsoft.azure.synapse.ml.services.form.FormOntologyTransformer",
"com.microsoft.azure.synapse.ml.services.anomaly.SimpleDetectMultivariateAnomaly",
"com.microsoft.azure.synapse.ml.train.ComputePerInstanceStatistics"
"com.microsoft.azure.synapse.ml.train.ComputePerInstanceStatistics",
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
)
val applicableStages = pipelineStages.filter(t => !exemptions(t.getClass.getName))
val applicableClasses = applicableStages.map(_.getClass.asInstanceOf[Class[_]]).toSet
@ -334,7 +332,9 @@ class FuzzingTest extends TestBase {
"com.microsoft.azure.synapse.ml.exploratory.AggregateBalanceMeasure",
"com.microsoft.azure.synapse.ml.exploratory.DistributionBalanceMeasure",
"com.microsoft.azure.synapse.ml.exploratory.FeatureBalanceMeasure",
"com.microsoft.azure.synapse.ml.isolationforest.IsolationForestModel"
"com.microsoft.azure.synapse.ml.isolationforest.IsolationForestModel",
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
)
pipelineStages.foreach { stage =>
@ -356,7 +356,9 @@ class FuzzingTest extends TestBase {
"com.microsoft.azure.synapse.ml.core.serialize.ComplexParamTest",
"com.microsoft.azure.synapse.ml.core.serialize.MixedParamTest",
"com.microsoft.azure.synapse.ml.core.serialize.StandardParamTest",
"com.microsoft.azure.synapse.ml.core.serialize.TestEstimatorBase"
"com.microsoft.azure.synapse.ml.core.serialize.TestEstimatorBase",
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
)
val clazz = classOf[SynapseMLLogging]