зеркало из https://github.com/microsoft/SynapseML.git
fix: companionModelClassName no longer returns generic type variable (#2195)
* fix: companionModelClassName no longer returns generic type variable * Move TestRegressor to its own file * added fuzzing exclusions * fixing style * removed comma ---------
This commit is contained in:
Родитель
580f7594e0
Коммит
50c7f1ebbd
|
@ -11,10 +11,9 @@ import org.apache.spark.ml.evaluation.Evaluator
|
|||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.{Estimator, Model, Transformer}
|
||||
|
||||
import java.lang.reflect.ParameterizedType
|
||||
import scala.reflect.runtime.universe._
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.nio.file.Files
|
||||
import scala.collection.Iterator.iterate
|
||||
|
||||
|
||||
trait BaseWrappable extends Params {
|
||||
|
@ -28,19 +27,24 @@ trait BaseWrappable extends Params {
|
|||
|
||||
protected lazy val classNameHelper: String = thisStage.getClass.getName.split(".".toCharArray).last
|
||||
|
||||
|
||||
protected def companionModelClassName: String = {
|
||||
val superClass = iterate[Class[_]](thisStage.getClass)(_.getSuperclass)
|
||||
.find(c => Set("Estimator", "ProbabilisticClassifier", "Predictor", "BaseRegressor", "Ranker")(
|
||||
c.getSuperclass.getSimpleName))
|
||||
.get
|
||||
val typeArgs = superClass.getGenericSuperclass.asInstanceOf[ParameterizedType].getActualTypeArguments
|
||||
val modelTypeArg = superClass.getSuperclass.getSimpleName match {
|
||||
case "Estimator" =>
|
||||
typeArgs.head
|
||||
case model if Set("ProbabilisticClassifier", "BaseRegressor", "Predictor", "Ranker")(model) =>
|
||||
typeArgs.last
|
||||
val symbol = scala.reflect.runtime.currentMirror.classSymbol(thisStage.getClass)
|
||||
|
||||
val superClassSymbol = symbol.baseClasses
|
||||
.find(s => Set("Estimator", "ProbabilisticClassifier", "Predictor", "BaseRegressor", "Ranker")
|
||||
.contains(s.name.toString))
|
||||
.getOrElse(throw new NoSuchElementException("Matching superclass was not found: " + symbol.baseClasses))
|
||||
|
||||
val typeArgs = symbol.toType.baseType(superClassSymbol).typeArgs
|
||||
|
||||
val modelTypeArg = superClassSymbol.name.toString match {
|
||||
case "Estimator" => typeArgs.head
|
||||
case _ => typeArgs.last
|
||||
}
|
||||
modelTypeArg.getTypeName
|
||||
|
||||
val modelName = modelTypeArg.typeSymbol.asClass.fullName
|
||||
modelName
|
||||
}
|
||||
|
||||
def getParamInfo(p: Param[_]): ParamInfo[_] = {
|
||||
|
|
|
@ -32,7 +32,13 @@ object JarLoadingUtils {
|
|||
def instantiateServices[T: ClassTag](instantiate: Class[_] => Any, jarName: Option[String]): List[T] = {
|
||||
AllClasses
|
||||
.filter(classTag[T].runtimeClass.isAssignableFrom(_))
|
||||
.filter(c => jarName.forall(c.getResource(c.getSimpleName + ".class").toString.contains(_)))
|
||||
.filter(c => jarName.forall({
|
||||
val jarResource = c.getResource(c.getSimpleName + ".class")
|
||||
if (jarResource == null) {
|
||||
throw new IOException(s"Could not find resource for class ${c.getSimpleName}")
|
||||
}
|
||||
jarResource.toString.contains(_)
|
||||
}))
|
||||
.filter(clazz => !Modifier.isAbstract(clazz.getModifiers))
|
||||
.map(instantiate(_)).asInstanceOf[List[T]]
|
||||
}
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License. See LICENSE in project root for information.
|
||||
|
||||
package com.microsoft.azure.synapse.ml.codegen
|
||||
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.regression.{RegressionModel, Regressor}
|
||||
import org.apache.spark.sql.Dataset
|
||||
|
||||
private[codegen] class TestRegressorModel extends RegressionModel[Vector, TestRegressorModel] {
|
||||
override def predict(features: Vector): Double = 0.0
|
||||
|
||||
override def copy(extra: ParamMap): TestRegressorModel = defaultCopy(extra)
|
||||
|
||||
override val uid: String = "test"
|
||||
}
|
||||
|
||||
private[codegen] class TestRegressor extends Regressor[Vector, TestRegressor, TestRegressorModel] with Wrappable {
|
||||
override def copy(extra: ParamMap): TestRegressor = defaultCopy(extra)
|
||||
|
||||
override protected def train(dataset: Dataset[_]): TestRegressorModel = new TestRegressorModel()
|
||||
|
||||
def getCompanionModelClassName(): String = this.companionModelClassName
|
||||
|
||||
override val uid: String = "test"
|
||||
}
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License. See LICENSE in project root for information.
|
||||
|
||||
package com.microsoft.azure.synapse.ml.codegen
|
||||
|
||||
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.regression.{RegressionModel, Regressor}
|
||||
import org.apache.spark.ml.util.Identifiable
|
||||
import org.apache.spark.sql.Dataset
|
||||
|
||||
class WrappableTests extends TestBase {
|
||||
|
||||
test ("test CompanionModelClassName") {
|
||||
val regressorCompanionModelClasName = new TestRegressor().getCompanionModelClassName
|
||||
assert(regressorCompanionModelClasName.equals(
|
||||
"com.microsoft.azure.synapse.ml.codegen.WrappableTests.TestRegressorModel"))
|
||||
}
|
||||
}
|
|
@ -3,26 +3,17 @@
|
|||
|
||||
package com.microsoft.azure.synapse.ml.core.test.fuzzing
|
||||
|
||||
import com.microsoft.azure.synapse.ml.Secrets
|
||||
import com.microsoft.azure.synapse.ml.build.BuildInfo
|
||||
import com.microsoft.azure.synapse.ml.services.{HasAADToken, HasSubscriptionKey}
|
||||
import com.microsoft.azure.synapse.ml.core.contracts.{HasFeaturesCol, HasInputCol, HasLabelCol, HasOutputCol}
|
||||
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
|
||||
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
|
||||
import com.microsoft.azure.synapse.ml.core.utils.JarLoadingUtils
|
||||
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
|
||||
import com.microsoft.azure.synapse.ml.services.{HasAADToken, HasSubscriptionKey}
|
||||
import org.apache.spark.ml._
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.util.{MLReadable, MLWritable}
|
||||
|
||||
import java.io.File
|
||||
import java.lang.reflect.ParameterizedType
|
||||
import java.nio.charset.MalformedInputException
|
||||
import java.nio.file.Files
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.io.Source
|
||||
import scala.language.existentials
|
||||
import scala.util.matching.Regex
|
||||
|
||||
/** Tests to validate fuzzing of modules. */
|
||||
class FuzzingTest extends TestBase {
|
||||
|
@ -79,8 +70,9 @@ class FuzzingTest extends TestBase {
|
|||
"com.microsoft.azure.synapse.ml.lightgbm.LightGBMRankerModel",
|
||||
"com.microsoft.azure.synapse.ml.services.form.FormOntologyTransformer",
|
||||
"com.microsoft.azure.synapse.ml.services.anomaly.SimpleDetectMultivariateAnomaly",
|
||||
"com.microsoft.azure.synapse.ml.automl.BestModel" //TODO add proper interfaces to all of these
|
||||
|
||||
"com.microsoft.azure.synapse.ml.automl.BestModel", //TODO add proper interfaces to all of these
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
|
||||
)
|
||||
val applicableStages = pipelineStages.filter(t => !exemptions(t.getClass.getName))
|
||||
val applicableClasses = applicableStages.map(_.getClass.asInstanceOf[Class[_]]).toSet
|
||||
|
@ -135,7 +127,9 @@ class FuzzingTest extends TestBase {
|
|||
"com.microsoft.azure.synapse.ml.services.DetectMultivariateAnomaly",
|
||||
"com.microsoft.azure.synapse.ml.services.form.FormOntologyTransformer",
|
||||
"com.microsoft.azure.synapse.ml.services.anomaly.SimpleDetectMultivariateAnomaly",
|
||||
"com.microsoft.azure.synapse.ml.vw.VowpalWabbitRegressionModel"
|
||||
"com.microsoft.azure.synapse.ml.vw.VowpalWabbitRegressionModel",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
|
||||
)
|
||||
val applicableStages = pipelineStages.filter(t => !exemptions(t.getClass.getName))
|
||||
val applicableClasses = applicableStages.map(_.getClass.asInstanceOf[Class[_]]).toSet
|
||||
|
@ -187,7 +181,9 @@ class FuzzingTest extends TestBase {
|
|||
"com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel",
|
||||
"com.microsoft.azure.synapse.ml.services.form.FormOntologyTransformer",
|
||||
"com.microsoft.azure.synapse.ml.services.anomaly.SimpleDetectMultivariateAnomaly",
|
||||
"com.microsoft.azure.synapse.ml.train.ComputePerInstanceStatistics"
|
||||
"com.microsoft.azure.synapse.ml.train.ComputePerInstanceStatistics",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
|
||||
)
|
||||
val applicableStages = pipelineStages.filter(t => !exemptions(t.getClass.getName))
|
||||
val applicableClasses = applicableStages.map(_.getClass.asInstanceOf[Class[_]]).toSet
|
||||
|
@ -241,7 +237,9 @@ class FuzzingTest extends TestBase {
|
|||
"com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel",
|
||||
"com.microsoft.azure.synapse.ml.services.form.FormOntologyTransformer",
|
||||
"com.microsoft.azure.synapse.ml.services.anomaly.SimpleDetectMultivariateAnomaly",
|
||||
"com.microsoft.azure.synapse.ml.train.ComputePerInstanceStatistics"
|
||||
"com.microsoft.azure.synapse.ml.train.ComputePerInstanceStatistics",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
|
||||
)
|
||||
val applicableStages = pipelineStages.filter(t => !exemptions(t.getClass.getName))
|
||||
val applicableClasses = applicableStages.map(_.getClass.asInstanceOf[Class[_]]).toSet
|
||||
|
@ -334,7 +332,9 @@ class FuzzingTest extends TestBase {
|
|||
"com.microsoft.azure.synapse.ml.exploratory.AggregateBalanceMeasure",
|
||||
"com.microsoft.azure.synapse.ml.exploratory.DistributionBalanceMeasure",
|
||||
"com.microsoft.azure.synapse.ml.exploratory.FeatureBalanceMeasure",
|
||||
"com.microsoft.azure.synapse.ml.isolationforest.IsolationForestModel"
|
||||
"com.microsoft.azure.synapse.ml.isolationforest.IsolationForestModel",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
|
||||
)
|
||||
|
||||
pipelineStages.foreach { stage =>
|
||||
|
@ -356,7 +356,9 @@ class FuzzingTest extends TestBase {
|
|||
"com.microsoft.azure.synapse.ml.core.serialize.ComplexParamTest",
|
||||
"com.microsoft.azure.synapse.ml.core.serialize.MixedParamTest",
|
||||
"com.microsoft.azure.synapse.ml.core.serialize.StandardParamTest",
|
||||
"com.microsoft.azure.synapse.ml.core.serialize.TestEstimatorBase"
|
||||
"com.microsoft.azure.synapse.ml.core.serialize.TestEstimatorBase",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressorModel",
|
||||
"com.microsoft.azure.synapse.ml.codegen.TestRegressor"
|
||||
)
|
||||
val clazz = classOf[SynapseMLLogging]
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче