Mirror of https://github.com/microsoft/SynapseML.git
feat: Reference dataset (#1977)
* Add pre-calculated reference dataset
* First commit
* build fixes
* fix dataset by samples count
* fixes
* complex params writeable
* remove numClasses from deserialization
* add feature names
* fix by using totalNumRows
* temp remove tests
* turn on generic tests
* refactor tests to smaller batches
* turn on more tests
* fix benchmark file
* moved tests around
* responded to comments
* more fixes
* more fixes
* add reference dataset test
* fix and add logs
* fix ref dataset test
This commit is contained in:
Parent: b23f050599
Commit: 5c93e4eff5
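For orientation before the diff itself: a minimal usage sketch of the feature this commit introduces, assuming the SynapseML LightGBMClassifier API. The setters shown (setDataTransferMode, setSamplingMode, setReferenceDataset, getReferenceDataset) are added by this commit; trainDF is assumed to be defined elsewhere.

    import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier

    // Train once in streaming mode; the estimator caches the serialized
    // reference dataset (the pre-computed bin boundaries) after the first fit.
    val classifier = new LightGBMClassifier()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setDataTransferMode("streaming")  // replaces the deprecated setExecutionMode
      .setSamplingMode("subset")         // "global", "subset", or "fixed"
    val model1 = classifier.fit(trainDF)

    // A second estimator can skip sampling entirely by reusing those bytes.
    val referenceBytes: Array[Byte] = classifier.getReferenceDataset
    val classifier2 = new LightGBMClassifier()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setDataTransferMode("streaming")
      .setReferenceDataset(referenceBytes)
    val model2 = classifier2.fit(trainDF)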
@@ -426,7 +426,7 @@ lazy val deepLearning = (project in file("deep-learning"))
 lazy val lightgbm = (project in file("lightgbm"))
   .dependsOn(core % "test->test;compile->compile")
   .settings(settings ++ Seq(
-    libraryDependencies += ("com.microsoft.ml.lightgbm" % "lightgbmlib" % "3.3.500"),
+    libraryDependencies += ("com.microsoft.ml.lightgbm" % "lightgbmlib" % "3.3.510"),
     name := "synapseml-lightgbm"
   ): _*)

@@ -68,7 +68,6 @@ class BulkPartitionTask extends BasePartitionTask {
       val datasetInner: LightGBMDataset = ac.generateDataset(ctx, referenceDataset)
       ctx.trainingCtx.columnParams.groupColumn.foreach(_ => datasetInner.addGroupColumn(ac.getGroups))
       datasetInner.setFeatureNames(ctx.trainingCtx.featureNames, ac.getNumCols)
-      datasetInner
     } finally {
       ac.cleanup()
     }

@@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.lightgbm
 import com.microsoft.azure.synapse.ml.core.utils.{ClusterUtil, ParamsStringBuilder}
 import com.microsoft.azure.synapse.ml.io.http.SharedSingleton
 import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster
-import com.microsoft.azure.synapse.ml.lightgbm.dataset.DatasetUtils
+import com.microsoft.azure.synapse.ml.lightgbm.dataset._
 import com.microsoft.azure.synapse.ml.lightgbm.params._
 import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
 import org.apache.spark.broadcast.Broadcast
@@ -14,9 +14,8 @@ import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol => HasFeaturesColSpark, HasLabelCol => HasLabelColSpark}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.{ComplexParamsWritable, Estimator, Model}
 import org.apache.spark.sql._
-import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types._

 import scala.collection.immutable.HashSet
@@ -24,8 +23,10 @@ import scala.language.existentials
 import scala.math.min
 import scala.util.matching.Regex

-trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[TrainedModel]
-  with LightGBMParams with HasFeaturesColSpark with HasLabelColSpark with LightGBMPerformance with SynapseMLLogging {
+// scalastyle:off file.size.limit
+trait LightGBMBase[TrainedModel <: Model[TrainedModel] with LightGBMModelParams] extends Estimator[TrainedModel]
+  with LightGBMParams with ComplexParamsWritable
+  with HasFeaturesColSpark with HasLabelColSpark with LightGBMPerformance with SynapseMLLogging {

   /** Trains the LightGBM model. If batches are specified, breaks training dataset into batches for training.
     *
@@ -196,7 +197,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
       .union(categoricalIndexes).distinct
   }

-  def getSlotNamesWithMetadata(featuresSchema: StructField): Option[Array[String]] = {
+  private def getSlotNamesWithMetadata(featuresSchema: StructField): Option[Array[String]] = {
     if (getSlotNames.nonEmpty) {
       Some(getSlotNames)
     } else {
@@ -311,7 +312,9 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     ExecutionParams(getChunkSize,
       getMatrixType,
       execNumThreads,
-      getExecutionMode,
+      getDataTransferMode,
+      getSamplingMode,
+      getSamplingSubsetSize,
       getMicroBatchSize,
       getUseSingleDatasetMode,
       getMaxStreamingOMPThreads)
@@ -368,7 +371,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
       get(maxCatToOnehot))
   }

-  def getDatasetCreationParams(categoricalIndexes: Array[Int], numThreads: Int): String = {
+  protected def getDatasetCreationParams(categoricalIndexes: Array[Int], numThreads: Int): String = {
     new ParamsStringBuilder(prefix = "", delimiter = "=")
       .appendParamValueIfNotThere("is_pre_partition", Option("True"))
       .appendParamValueIfNotThere("max_bin", Option(getMaxBin))
@@ -390,7 +393,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     * @param batchIndex In running in batch training mode, gets the batch number.
     * @return The LightGBM Model from the trained LightGBM Booster.
     */
-  protected def trainOneDataBatch(dataset: Dataset[_], batchIndex: Int, batchCount: Int): TrainedModel = {
+  private def trainOneDataBatch(dataset: Dataset[_], batchIndex: Int, batchCount: Int): TrainedModel = {
     val measures = new InstrumentationMeasures()
     setBatchPerformanceMeasure(batchIndex, measures)

@@ -415,17 +418,24 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     val trainParams = addCustomTrainParams(generalTrainParams, dataset)
     log.info(s"LightGBM batch $batchIndex of $batchCount, parameters: ${trainParams.toString()}")

-    val isStreamingMode = getExecutionMode == LightGBMConstants.StreamingExecutionMode
-    val (broadcastedSampleData: Option[Broadcast[Array[Row]]], partitionCounts: Option[Array[Long]]) =
+    val isStreamingMode = getDataTransferMode == LightGBMConstants.StreamingDataTransferMode
+    val (serializedReferenceDataset: Option[Array[Byte]], partitionCounts: Option[Array[Long]]) =
       if (isStreamingMode) {
-        val (sampledData, partitionCounts) = calculateRowStatistics(trainingData, trainParams, numCols, measures)
-        (Some(sc.broadcast(sampledData)), Some(partitionCounts))
+        val (referenceDataset, partitionCounts) =
+          calculateRowStatistics(trainingData, trainParams, numCols, measures)
+
+        // Save the reference Dataset so it's available to client and other batches
+        if (getReferenceDataset.isEmpty) {
+          log.info(s"Saving reference dataset of length: ${referenceDataset.length}")
+          setReferenceDataset(referenceDataset)
+        }
+        (Some(referenceDataset), Some(partitionCounts))
       } else (None, None)

     validateSlotNames(featuresSchema)
     executeTraining(preprocessedDF,
       validationData,
-      broadcastedSampleData,
+      serializedReferenceDataset,
       partitionCounts,
       trainParams,
       numCols,
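The block above computes the serialized reference dataset only when the estimator does not already carry one, then stores it back into the referenceDataset param so later batches (and callers) reuse it. A self-contained sketch of that compute-once pattern, with illustrative names rather than SynapseML API:

    // Illustrative compute-once cache mirroring the getReferenceDataset.isEmpty
    // check: the first caller pays the sampling cost, later callers reuse the bytes.
    final class ReferenceCache(compute: () => Array[Byte]) {
      private var cached: Array[Byte] = Array.empty[Byte]
      def get(): Array[Byte] = synchronized {
        if (cached.isEmpty) cached = compute()
        cached
      }
    }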
@@ -465,7 +475,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     * @param dataframe The dataset to train on.
     * @return The number of feature columns and initial score classes
     */
-  protected def calculateColumnStatistics(dataframe: DataFrame, measures: InstrumentationMeasures): (Int, Int) = {
+  private def calculateColumnStatistics(dataframe: DataFrame, measures: InstrumentationMeasures): (Int, Int) = {
     measures.markColumnStatisticsStart()
     // Use the first row to get the column count
     val firstRow: Row = dataframe.first()
@@ -488,18 +498,18 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
   }

   /**
-    * Inner train method for LightGBM learners. Calculates the number of workers,
-    * creates a driver thread, and runs mapPartitions on the dataset.
+    * Calculate row statistics for streaming mode. Gets reference data set and partition counts.
     *
     * @param dataframe The dataset to train on.
     * @param trainingParams The training parameters.
     * @param numCols The number of feature columns.
     * @param measures Instrumentation measures.
+    * @return The serialized Dataset reference and an array of partition counts.
     */
-  protected def calculateRowStatistics(dataframe: DataFrame,
-                                       trainingParams: BaseTrainParams,
-                                       numCols: Int,
-                                       measures: InstrumentationMeasures): (Array[Row], Array[Long]) = {
+  private def calculateRowStatistics(dataframe: DataFrame,
+                                     trainingParams: BaseTrainParams,
+                                     numCols: Int,
+                                     measures: InstrumentationMeasures): (Array[Byte], Array[Long]) = {
     measures.markRowStatisticsStart()

     // Get the row counts per partition
@@ -509,43 +519,56 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     val totalNumRows = rowCounts.sum
     measures.markRowCountsStop()

-    // Get sample data using sample() function in Spark
-    // TODO optimize with just a take() in case of user approval
-    measures.markSamplingStart()
-    val sampleCount: Int = getBinSampleCount
-    val seed: Int = getSeedParams.dataRandomSeed.getOrElse(0)
-    val featureColName = getFeaturesCol
-    val fraction = if (sampleCount > totalNumRows) 1.0
-                   else Math.min(1.0, (sampleCount.toDouble + 10000)/totalNumRows)
-    val numSamples = Math.min(sampleCount, totalNumRows).toInt
-    val rawSampleData = dataframe.select(dataframe.col(featureColName)).sample(fraction, seed).limit(numSamples)
-    val collectedSampleData = rawSampleData.collect()
-    measures.markSamplingStop()
+    val datasetParams = getDatasetCreationParams(
+      trainingParams.generalParams.categoricalFeatures,
+      trainingParams.executionParams.numThreads)
+
+    // Either get a reference dataset (as bytes) from params, or calculate it
+    val precalculatedDataset = getReferenceDataset
+    val serializedReference = if (precalculatedDataset.nonEmpty) {
+      log.info(s"Using precalculated reference Dataset of length: ${precalculatedDataset.length}")
+      precalculatedDataset
+    } else {
+      // Get sample data rows
+      measures.markSamplingStart()
+      val collectedSampleData = getSampledRows(dataframe, totalNumRows, trainingParams)
+      log.info(s"Using ${collectedSampleData.length} sample rows")
+      measures.markSamplingStop()
+
+      ReferenceDatasetUtils.createReferenceDatasetFromSample(
+        datasetParams,
+        getFeaturesCol,
+        totalNumRows,
+        numCols,
+        collectedSampleData,
+        measures,
+        log)
+    }
+
     measures.markRowStatisticsStop()
-    (collectedSampleData, rowCounts)
+    (serializedReference, rowCounts)
   }

   /**
-    * Run a parallel job via map partitions to initialize the native library and network,
-    * translate the data to the LightGBM in-memory representation and train the models.
-    *
-    * @param dataframe The dataset to train on.
-    * @param validationData The dataset to use as validation. (optional)
-    * @param broadcastedSampleData Sample data to use for streaming mode Dataset creation (optional).
-    * @param partitionCounts The count per partition for streaming mode (optional).
-    * @param trainParams Training parameters.
-    * @param numCols Number of columns.
-    * @param numInitValueClasses Number of classes for initial values (used only for multiclass).
-    * @param batchIndex In running in batch training mode, gets the batch number.
-    * @param numTasks Number of tasks/partitions.
-    * @param numTasksPerExecutor Number of tasks per executor.
-    * @param measures Instrumentation measures to populate.
-    * @return The LightGBM Model from the trained LightGBM Booster.
-    */
-  protected def executeTraining(dataframe: DataFrame,
+    * Run a parallel job via map partitions to initialize the native library and network,
+    * translate the data to the LightGBM in-memory representation and train the models.
+    *
+    * @param dataframe The dataset to train on.
+    * @param validationData The dataset to use as validation. (optional)
+    * @param serializedReferenceDataset The serialized reference dataset (optional).
+    * @param partitionCounts The count per partition for streaming mode (optional).
+    * @param trainParams Training parameters.
+    * @param numCols Number of columns.
+    * @param numInitValueClasses Number of classes for initial values (used only for multiclass).
+    * @param batchIndex In running in batch training mode, gets the batch number.
+    * @param numTasks Number of tasks/partitions.
+    * @param numTasksPerExecutor Number of tasks per executor.
+    * @param measures Instrumentation measures to populate.
+    * @return The LightGBM Model from the trained LightGBM Booster.
+    */
+  private def executeTraining(dataframe: DataFrame,
                               validationData: Option[Broadcast[Array[Row]]],
-                              broadcastedSampleData: Option[Broadcast[Array[Row]]],
+                              serializedReferenceDataset: Option[Array[Byte]],
                               partitionCounts: Option[Array[Long]],
                               trainParams: BaseTrainParams,
                               numCols: Int,
@@ -561,9 +584,10 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
       getUseBarrierExecutionMode)
     val ctx = getTrainingContext(dataframe,
       validationData,
-      broadcastedSampleData,
+      serializedReferenceDataset,
       partitionCounts,
       trainParams,
+      getSlotNamesWithMetadata(dataframe.schema(getFeaturesCol)),
       numCols,
       numInitValueClasses,
       batchIndex,
@@ -581,7 +605,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     model
   }

-  protected def executePartitionTasks(ctx: TrainingContext,
+  private def executePartitionTasks(ctx: TrainingContext,
                                       dataframe: DataFrame,
                                       measures: InstrumentationMeasures): LightGBMBooster = {
     // Create the object that will manage the mapPartitions function
@@ -608,7 +632,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     *
     * @param dataframe The dataset to train on.
     * @param validationData The dataset to use as validation. (optional)
-    * @param broadcastedSampleData Sample data to use for streaming mode Dataset creation (optional).
+    * @param serializedReferenceDataset The serialized reference Dataset. (optional).
     * @param partitionCounts The count per partition for streaming mode (optional).
     * @param trainParams Training parameters.
     * @param numCols Number of columns.
@@ -618,11 +642,12 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     * @param networkManager The network manager.
     * @return The context of the training session.
     */
-  protected def getTrainingContext(dataframe: DataFrame,
+  private def getTrainingContext(dataframe: DataFrame,
                                   validationData: Option[Broadcast[Array[Row]]],
-                                  broadcastedSampleData: Option[Broadcast[Array[Row]]],
+                                  serializedReferenceDataset: Option[Array[Byte]],
                                   partitionCounts: Option[Array[Long]],
                                   trainParams: BaseTrainParams,
+                                  featureNames: Option[Array[String]],
                                   numCols: Int,
                                   numInitValueClasses: Int,
                                   batchIndex: Int,
@@ -647,10 +672,10 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
       networkParams,
       getColumnParams,
       datasetParams,
-      getSlotNamesWithMetadata(dataframe.schema(getFeaturesCol)),
+      featureNames,
       numTasksPerExecutor,
       validationData,
-      broadcastedSampleData,
+      serializedReferenceDataset,
       partitionCounts)
   }

@@ -687,4 +712,39 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
     * @return The preprocessed data.
     */
   protected def preprocessData(df: DataFrame): DataFrame = df
+
+  /**
+    * Creates an array of Rows to use as sample data.
+    *
+    * @param dataframe The dataset to train on.
+    * @param totalNumRows The total number of rows in the dataset.
+    * @param trainingParams The training parameters.
+    * @return The serialized Dataset reference and an array of partition counts.
+    */
+  private def getSampledRows(dataframe: DataFrame,
+                             totalNumRows: Long,
+                             trainingParams: BaseTrainParams): Array[Row] = {
+    val sampleCount: Int = getBinSampleCount
+    val seed: Int = getSeedParams.dataRandomSeed.getOrElse(0)
+    val featureColName = getFeaturesCol
+    val fraction = if (sampleCount > totalNumRows) 1.0
+                   else Math.min(1.0, (sampleCount.toDouble + 10000) / totalNumRows)
+    val numSamples = Math.min(sampleCount, totalNumRows).toInt
+    val samplingSubsetSize = Math.min(trainingParams.executionParams.samplingSetSize, numSamples)
+
+    val samplingMode = trainingParams.executionParams.samplingMode
+    log.info(s"Using sampling mode: $samplingMode (if subset, size is $samplingSubsetSize)")
+    samplingMode match {
+      case LightGBMConstants.SubsetSamplingModeGlobal =>
+        // sample randomly from all rows (expensive for large data sets)
+        dataframe.select(dataframe.col(featureColName)).sample(fraction, seed).limit(numSamples).collect()
+      case LightGBMConstants.SubsetSamplingModeSubset =>
+        // sample randomly from first 'samplingSetSize' rows (optimization to save time on large data sets)
+        dataframe.select(dataframe.col(featureColName)).limit(samplingSubsetSize).sample(fraction, seed).collect()
+      case LightGBMConstants.SubsetSamplingModeFixed =>
+        // just take first 'N' rows. Quick but assumes data already randomized and representative.
+        dataframe.select(dataframe.col(featureColName)).limit(numSamples).collect()
+      case _ => throw new NotImplementedError(s"Unknown sampling mode: $samplingMode")
+    }
+  }
 }
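To make the sampling arithmetic in getSampledRows concrete, here is the same computation with illustrative numbers (sampleCount normally comes from getBinSampleCount; the values below are assumptions, not defaults):

    val sampleCount = 100000            // assumed bin-sample count
    val totalNumRows = 2000000L         // assumed dataset size
    // Over-request by 10000 rows because DataFrame.sample(fraction, seed) is
    // only approximate and can return slightly fewer rows than fraction implies.
    val fraction = if (sampleCount > totalNumRows) 1.0
                   else Math.min(1.0, (sampleCount.toDouble + 10000) / totalNumRows)
    // fraction == 110000.0 / 2000000 == 0.055 here
    val numSamples = Math.min(sampleCount, totalNumRows).toInt  // == 100000
    // limit(numSamples) then trims the over-sampled rows back down to the target.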
@@ -31,12 +31,21 @@ object LightGBMConstants {
   /** Multiclass classification objective
     */
   val MulticlassObjective: String = "multiclass"
-  /** Streaming execution mode.
-    */
-  val StreamingExecutionMode: String = "streaming"
-  /** Bulk execution mode.
-    */
-  val BulkExecutionMode: String = "bulk"
+  /** Streaming data transfer mode.
+    */
+  val StreamingDataTransferMode: String = "streaming"
+  /** Bulk data transfer mode.
+    */
+  val BulkDataTransferMode: String = "bulk"
+  /** Sampling mode - random n within all global data.
+    */
+  val SubsetSamplingModeGlobal: String = "global"
+  /** Sampling mode - random n within subset N.
+    */
+  val SubsetSamplingModeSubset: String = "subset"
+  /** Sampling mode take first n rows.
+    */
+  val SubsetSamplingModeFixed: String = "fixed"
   /** Enabled task, used to indicate task that creates lightgbm dataset and runs training.
     */
   val EnabledTask: String = "enabledTask"
@@ -3,7 +3,7 @@

 package com.microsoft.azure.synapse.ml.lightgbm

-import com.microsoft.azure.synapse.ml.lightgbm.dataset.{LightGBMDataset, SampledData}
+import com.microsoft.azure.synapse.ml.lightgbm.dataset.{LightGBMDataset, ReferenceDatasetUtils}
 import com.microsoft.azure.synapse.ml.lightgbm.swig._
 import com.microsoft.ml.lightgbm._
 import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
@@ -113,9 +113,9 @@ class StreamingPartitionTask extends BasePartitionTask {

   protected def preparePartitionDataInternal(ctx: PartitionTaskContext,
                                              inputRows: Iterator[Row]): PartitionDataState = {
-    // If this is the task that will execute training, first create the empty Dataset from sampled data
+    // If this is the task that will execute training, first create the empty Dataset from context
     if (ctx.shouldExecuteTraining) {
-      ctx.sharedState.datasetState.streamingDataset = Option(createSharedExecutorDataset(ctx))
+      ctx.sharedState.datasetState.streamingDataset = Option(ReferenceDatasetUtils.getInitializedReferenceDataset(ctx))
       ctx.sharedState.helperStartSignal.countDown()
     } else {
       // This must be a task that just loads data and exits, so wait for the shared Dataset to be created
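The branch above encodes a per-executor rendezvous: the one task that will run training creates the shared streaming Dataset and then releases the helper tasks blocked on helperStartSignal. A generic sketch of that handoff, using java.util.concurrent.CountDownLatch as the shared state does (names are illustrative):

    import java.util.concurrent.CountDownLatch

    object RendezvousSketch {
      val startSignal = new CountDownLatch(1)
      @volatile var sharedDataset: Option[String] = None  // stand-in for the native Dataset

      def runTask(shouldExecuteTraining: Boolean): Unit =
        if (shouldExecuteTraining) {
          sharedDataset = Some("initialized")  // create the shared resource first
          startSignal.countDown()              // then wake the waiting helper tasks
        } else {
          startSignal.await()                  // block until the dataset exists
          assert(sharedDataset.nonEmpty)
        }
    }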
@@ -351,57 +351,6 @@ class StreamingPartitionTask extends BasePartitionTask {
     }
   }

-  private def createSharedExecutorDataset(ctx: PartitionTaskContext): LightGBMDataset = {
-    // The sample data is broadcast from Spark, so retrieve it
-    ctx.measures.markSamplingStart()
-    val numRows = ctx.executorRowCount
-    val sampledRowData = ctx.trainingCtx.broadcastedSampleData.get.value
-
-    // create properly formatted sampled data
-    log.info(s"Loading sample data from broadcast with ${sampledRowData.length} samples")
-    val datasetVoidPtr = lightgbmlib.voidpp_handle()
-    val sampledData: SampledData = new SampledData(sampledRowData.length, ctx.trainingCtx.numCols)
-    try {
-      sampledRowData.zipWithIndex.foreach(rowAndIndex =>
-        sampledData.pushRow(rowAndIndex._1, rowAndIndex._2, ctx.trainingCtx.columnParams.featuresColumn))
-      ctx.measures.markSamplingStop()
-
-      // Convert byte array to native memory
-      log.info(s"Creating empty training dataset with $numRows rows, config:${ctx.trainingCtx.datasetParams}" +
-        s" for ${ctx.executorPartitionCount} threads")
-      // Generate the dataset for features
-      val datasetVoidPtr = lightgbmlib.voidpp_handle()
-      LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromSampledColumn(
-        sampledData.getSampleData,
-        sampledData.getSampleIndices,
-        ctx.trainingCtx.numCols,
-        sampledData.getRowCounts,
-        sampledData.numRows,
-        numRows,
-        ctx.totalRowCount,
-        ctx.trainingCtx.datasetParams,
-        datasetVoidPtr), "Dataset create")
-
-      val datasetPtr: SWIGTYPE_p_void = lightgbmlib.voidpp_value(datasetVoidPtr)
-      val maxOmpThreads = ctx.trainingCtx.trainingParams.executionParams.maxStreamingOMPThreads
-      LightGBMUtils.validate(lightgbmlib.LGBM_DatasetInitStreaming(datasetPtr,
-        ctx.trainingCtx.hasWeightsAsInt,
-        ctx.trainingCtx.hasInitialScoresAsInt,
-        ctx.trainingCtx.hasGroupsAsInt,
-        ctx.trainingCtx.trainingParams.getNumClass,
-        ctx.executorPartitionCount,
-        maxOmpThreads),
-        "LGBM_DatasetInitStreaming")
-
-      val dataset = new LightGBMDataset(datasetPtr)
-      dataset.setFeatureNames(ctx.trainingCtx.featureNames, ctx.trainingCtx.numCols)
-      dataset
-    } finally {
-      sampledData.delete()
-      lightgbmlib.delete_voidpp(datasetVoidPtr)
-    }
-  }
-
   private def createSharedValidationDataset(ctx: PartitionTaskContext, rowCount: Int): LightGBMDataset = {
     val pointer = lightgbmlib.voidpp_handle()
     val reference = ctx.sharedState.datasetState.streamingDataset.get.datasetPtr
@@ -417,6 +366,5 @@ class StreamingPartitionTask extends BasePartitionTask {
     lightgbmlib.delete_voidpp(pointer)
     val dataset = new LightGBMDataset(datasetPtr)
     dataset.setFeatureNames(ctx.trainingCtx.featureNames, ctx.trainingCtx.numCols)
-    dataset
   }
 }
@@ -39,15 +39,15 @@ case class TrainingContext(batchIndex: Int,
                            featureNames: Option[Array[String]],
                            numTasksPerExecutor: Int,
                            validationData: Option[Broadcast[Array[Row]]],
-                           broadcastedSampleData: Option[Broadcast[Array[Row]]],
+                           serializedReferenceDataset: Option[Array[Byte]],
                            partitionCounts: Option[Array[Long]]) extends Serializable {
   val isProvideTrainingMetric: Boolean = { trainingParams.isProvideTrainingMetric.getOrElse(false) }
   val improvementTolerance: Double = { trainingParams.generalParams.improvementTolerance }
   val earlyStoppingRound: Int = { trainingParams.generalParams.earlyStoppingRound }
   val microBatchSize: Int = { trainingParams.executionParams.microBatchSize }

-  val isStreaming: Boolean = trainingParams.executionParams.executionMode == LightGBMConstants.StreamingExecutionMode
-  val isBulk: Boolean = trainingParams.executionParams.executionMode == LightGBMConstants.BulkExecutionMode
+  val isStreaming = trainingParams.executionParams.dataTransferMode == LightGBMConstants.StreamingDataTransferMode
+  val isBulk = trainingParams.executionParams.dataTransferMode == LightGBMConstants.BulkDataTransferMode

   val useSingleDatasetMode: Boolean = trainingParams.executionParams.useSingleDatasetMode || isStreaming

@@ -175,14 +175,15 @@ class LightGBMDataset(val datasetPtr: SWIGTYPE_p_void) extends AutoCloseable {
     addIntField(groupCardinality, "group", groupCardinality.length)
   }

-  def setFeatureNames(featureNamesOpt: Option[Array[String]], numCols: Int): Unit = {
+  def setFeatureNames(featureNamesOpt: Option[Array[String]], numCols: Int): LightGBMDataset = {
     // Add in slot names if they exist
-    featureNamesOpt.foreach { featureNamesVal =>
-      if (featureNamesVal.nonEmpty) {
-        LightGBMUtils.validate(lightgbmlib.LGBM_DatasetSetFeatureNames(datasetPtr, featureNamesVal, numCols),
+    featureNamesOpt.foreach { featureNamesArray =>
+      if (featureNamesArray.nonEmpty) {
+        LightGBMUtils.validate(lightgbmlib.LGBM_DatasetSetFeatureNames(datasetPtr, featureNamesArray, numCols),
           "Dataset set feature names")
       }
     }
+    this
   }

   override def close(): Unit = {
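Returning the dataset itself (rather than Unit) lets setFeatureNames sit in tail position, which is why the trailing dataset/datasetInner expressions disappear from the partition tasks elsewhere in this commit:

    // Before: two statements, with the dataset repeated as the return value
    //   dataset.setFeatureNames(featureNames, numCols)
    //   dataset
    // After: the call itself is the returned LightGBMDataset
    //   dataset.setFeatureNames(featureNames, numCols)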
@@ -0,0 +1,139 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.lightgbm.dataset
+
+import com.microsoft.azure.synapse.ml.lightgbm.swig.SwigUtils
+import com.microsoft.azure.synapse.ml.lightgbm._
+import com.microsoft.ml.lightgbm._
+import org.apache.spark.sql._
+import org.slf4j.Logger
+
+
+object ReferenceDatasetUtils {
+  def createReferenceDatasetFromSample(datasetParams: String,
+                                       featuresCol: String,
+                                       numRows: Long,
+                                       numCols: Int,
+                                       sampledRowData: Array[Row],
+                                       measures: InstrumentationMeasures,
+                                       log: Logger): Array[Byte] = {
+    log.info(s"Creating reference training dataset with ${sampledRowData.length} samples and config: $datasetParams")
+
+    // Pre-create allocated native pointers so it's easy to clean them up
+    val datasetVoidPtr = lightgbmlib.voidpp_handle()
+    val lenPtr = lightgbmlib.new_intp()
+    val bufferHandlePtr = lightgbmlib.voidpp_handle()
+
+    val sampledData = SampledData(sampledRowData.length, numCols)
+    try {
+      // create properly formatted sampled data
+      measures.markSamplingStart()
+      sampledRowData.zipWithIndex.foreach({case (row, index) => sampledData.pushRow(row, index, featuresCol)})
+      measures.markSamplingStop()
+
+      // Create dataset from samples
+      // 1. Generate the dataset for features
+      val datasetVoidPtr = lightgbmlib.voidpp_handle()
+      LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromSampledColumn(
+        sampledData.getSampleData,
+        sampledData.getSampleIndices,
+        numCols,
+        sampledData.getRowCounts,
+        sampledData.numRows,
+        1, // Used for allocation and must be > 0, but we don't use this reference set for data collection
+        numRows,
+        datasetParams,
+        datasetVoidPtr), "Dataset create from samples")
+
+
+      // 2. Serialize the raw dataset to a native buffer
+      val datasetHandle: SWIGTYPE_p_void = lightgbmlib.voidpp_value(datasetVoidPtr)
+      LightGBMUtils.validate(lightgbmlib.LGBM_DatasetSerializeReferenceToBinary(
+        datasetHandle,
+        bufferHandlePtr,
+        lenPtr), "Serialize ref")
+      val bufferLen: Int = lightgbmlib.intp_value(lenPtr)
+      log.info(s"Created serialized reference dataset of length $bufferLen")
+
+      // The dataset is now serialized to a buffer, so we don't need original
+      LightGBMUtils.validate(lightgbmlib.LGBM_DatasetFree(datasetHandle), "Free Dataset")
+
+      // This will also free the buffer
+      toByteArray(bufferHandlePtr, bufferLen)
+    }
+    finally {
+      sampledData.delete()
+      lightgbmlib.delete_voidpp(datasetVoidPtr)
+      lightgbmlib.delete_voidpp(bufferHandlePtr)
+      lightgbmlib.delete_intp(lenPtr)
+    }
+  }
+
+  def getInitializedReferenceDataset(ctx: PartitionTaskContext): LightGBMDataset = {
+    // The definition is broadcast from Spark, so retrieve it
+    val serializedDataset: Array[Byte] = ctx.trainingCtx.serializedReferenceDataset.get
+
+    // Convert byte array to actual dataset
+    val count = ctx.executorRowCount
+    val datasetParams = ctx.trainingCtx.datasetParams
+    val lightGBMDataset = deserializeReferenceDataset(
+      serializedDataset,
+      count,
+      datasetParams)
+
+    // Initialize the dataset for streaming (allocates arrays mostly)
+    val maxOmpThreads = ctx.trainingParams.executionParams.maxStreamingOMPThreads
+    LightGBMUtils.validate(lightgbmlib.LGBM_DatasetInitStreaming(lightGBMDataset.datasetPtr,
+      ctx.trainingCtx.hasWeightsAsInt,
+      ctx.trainingCtx.hasInitialScoresAsInt,
+      ctx.trainingCtx.hasGroupsAsInt,
+      ctx.trainingParams.getNumClass,
+      ctx.executorPartitionCount,
+      maxOmpThreads),
+      "LGBM_DatasetInitStreaming")
+
+    lightGBMDataset.setFeatureNames(ctx.trainingCtx.featureNames, ctx.trainingCtx.numCols)
+  }
+
+  private def toByteArray(buffer: SWIGTYPE_p_p_void, bufferLen: Int): Array[Byte] = {
+    val byteArray = new Array[Byte](bufferLen)
+    val valPtr = lightgbmlib.new_bytep()
+    val bufferHandle = lightgbmlib.voidpp_value(buffer)
+
+    try
+    {
+      (0 until bufferLen).foreach { i =>
+        LightGBMUtils.validate(lightgbmlib.LGBM_ByteBufferGetAt(bufferHandle, i, valPtr), "Buffer get-at")
+        byteArray(i) = lightgbmlib.bytep_value(valPtr).toByte
+      }
+    }
+    finally
+    {
+      // We assume once converted to byte array we should clean up the native memory and buffer
+      lightgbmlib.delete_bytep(valPtr)
+      LightGBMUtils.validate(lightgbmlib.LGBM_ByteBufferFree(bufferHandle), "Buffer free")
+    }
+
+    byteArray
+  }
+
+  private def deserializeReferenceDataset(serializedDataset: Array[Byte],
+                                          rowCount: Int,
+                                          datasetParams: String): LightGBMDataset = {
+    // Convert byte array to native memory
+    val datasetVoidPtr = lightgbmlib.voidpp_handle()
+    val nativeByteArray = SwigUtils.byteArrayToNative(serializedDataset)
+    LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromSerializedReference( //scalastyle:ignore token
+      lightgbmlib.byte_to_voidp_ptr(nativeByteArray),
+      serializedDataset.length,
+      rowCount,
+      0, // Always zero since we will be using InitStreaming to do allocation
+      datasetParams,
+      datasetVoidPtr), "Dataset create from reference")
+
+    val datasetPtr: SWIGTYPE_p_void = lightgbmlib.voidpp_value(datasetVoidPtr)
+    lightgbmlib.delete_voidpp(datasetVoidPtr)
+    new LightGBMDataset(datasetPtr)
+  }
+}
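Since the reference dataset is carried around as a plain Array[Byte], it can also be persisted outside the estimator, for example to skip sampling across Spark jobs. A sketch using only standard java.nio APIs (the helper names and path are illustrative, not part of this commit):

    import java.nio.file.{Files, Paths}

    // Stash the bytes produced by createReferenceDatasetFromSample so a later
    // job can call setReferenceDataset(loadReference(...)) and skip sampling.
    def saveReference(bytes: Array[Byte], path: String): Unit =
      Files.write(Paths.get(path), bytes)

    def loadReference(path: String): Array[Byte] =
      Files.readAllBytes(Paths.get(path))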
@@ -23,7 +23,7 @@ import com.microsoft.ml.lightgbm._
   * .
   * Note: sample data row count is not expected to exceed max(Int), so we index with Ints.
   */
-class SampledData(val numRows: Int, val numCols: Int) {
+case class SampledData(numRows: Int, numCols: Int) {

   // Allocate full arrays for each feature column, but we will push only non-zero values and
   // keep track of actual counts in rowCounts array
@@ -67,7 +67,7 @@ class SampledData(val numRows: Int, val numCols: Int) {
       pushRowElementIfNotZero(rowData.indices(i), rowData.values(i), index))
   }

-  def pushRowElementIfNotZero(col: Int, value: Double, index: Int): Unit = {
+  private def pushRowElementIfNotZero(col: Int, value: Double, index: Int): Unit = {
     if (value != 0.0) {
       val nextIndex = rowCounts.getItem(col)
       sampleData.pushElement(col, nextIndex, value)
@@ -173,7 +173,9 @@ case class DartModeParams(dropRate: Double,
   * @param matrixType Advanced parameter to specify whether the native lightgbm matrix
   *                   constructed should be sparse or dense.
   * @param numThreads The number of threads to run the native lightgbm training with on each worker.
-  * @param executionMode How to execute the LightGBM training.
+  * @param dataTransferMode How to transfer data to LightGBM to begin the processing.
+  * @param samplingMode How to sample data.
+  * @param samplingSetSize The size of the subset if sampling only a subset.
   * @param microBatchSize The number of elements in a streaming micro-batch.
   * @param useSingleDatasetMode Whether to create only 1 LightGBM Dataset on each worker.
   * @param maxStreamingOMPThreads Maximum number of streaming mode OpenMP threads per Spark Task thread.
@@ -181,7 +183,9 @@ case class DartModeParams(dropRate: Double,
 case class ExecutionParams(chunkSize: Int,
                            matrixType: String,
                            numThreads: Int,
-                           executionMode: String,
+                           dataTransferMode: String,
+                           samplingMode: String,
+                           samplingSetSize: Int,
                            microBatchSize: Int,
                            useSingleDatasetMode: Boolean,
                            maxStreamingOMPThreads: Int) extends ParamGroup {
@@ -7,6 +7,7 @@ import com.microsoft.azure.synapse.ml.codegen.Wrappable
 import com.microsoft.azure.synapse.ml.core.contracts.{HasInitScoreCol, HasValidationIndicatorCol, HasWeightCol}
 import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster
 import com.microsoft.azure.synapse.ml.lightgbm.{LightGBMConstants, LightGBMDelegate}
+import com.microsoft.azure.synapse.ml.param.ByteArrayParam
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.DefaultParamsWritable
@@ -58,12 +59,42 @@ trait LightGBMExecutionParams extends Wrappable {
   def getUseBarrierExecutionMode: Boolean = $(useBarrierExecutionMode)
   def setUseBarrierExecutionMode(value: Boolean): this.type = set(useBarrierExecutionMode, value)

+  val samplingMode = new Param[String](this, "samplingMode",
+    "Data sampling for streaming mode. Sampled data is used to define bins. " +
+    "'global': sample from all data, 'subset': sample from first N rows, or 'fixed': Take first N rows as sample." +
+    "Values can be global, subset, or fixed. Default is subset.")
+  setDefault(samplingMode -> LightGBMConstants.SubsetSamplingModeSubset)
+  def getSamplingMode: String = $(samplingMode)
+  def setSamplingMode(value: String): this.type = set(samplingMode, value)
+
+  val samplingSubsetSize = new IntParam(this, "samplingSubsetSize",
+    "Specify subset size N for the sampling mode 'subset'. 'binSampleCount' rows will be chosen from " +
+    "the first N values of the dataset. Subset can be used when rows are expected to be random and data is huge.")
+  setDefault(samplingSubsetSize -> 1000000)
+  def getSamplingSubsetSize: Int = $(samplingSubsetSize)
+  def setSamplingSubsetSize(value: Int): this.type = set(samplingSubsetSize, value)
+
+  val referenceDataset: ByteArrayParam = new ByteArrayParam(
+    this,
+    "referenceDataset",
+    "The reference Dataset that was used for the fit. If using samplingMode=custom, this must be set before fit()."
+  )
+  setDefault(referenceDataset -> Array.empty[Byte])
+  def getReferenceDataset: Array[Byte] = $(referenceDataset)
+  def setReferenceDataset(value: Array[Byte]): this.type = set(referenceDataset, value)
+
+  @deprecated("Please use 'dataTransferMode'", since = "0.11.1")
   val executionMode = new Param[String](this, "executionMode",
-    "Specify how LightGBM is executed. " +
-    "Values can be streaming, bulk. Default is bulk.")
-  setDefault(executionMode -> LightGBMConstants.BulkExecutionMode)
-  def getExecutionMode: String = $(executionMode)
-  def setExecutionMode(value: String): this.type = set(executionMode, value)
+    "Deprecated. Please use dataTransferMode.")
+  @deprecated("Please use 'setDataTransferMode'", since = "0.11.1")
+  def setExecutionMode(value: String): this.type = set(dataTransferMode, value)
+
+  val dataTransferMode = new Param[String](this, "dataTransferMode",
+    "Specify how SynapseML transfers data from Spark to LightGBM. " +
+    "Values can be streaming, bulk. Default is bulk, which is the legacy mode.")
+  setDefault(dataTransferMode -> LightGBMConstants.BulkDataTransferMode)
+  def getDataTransferMode: String = $(dataTransferMode)
+  def setDataTransferMode(value: String): this.type = set(dataTransferMode, value)

   val microBatchSize = new IntParam(this, "microBatchSize",
     "Specify how many elements are sent in a streaming micro-batch.")
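Taken together, the new params can be exercised roughly as below. This is a sketch, not part of the commit; LightGBMClassifier is one of the learners mixing in this trait, and the string values mirror LightGBMConstants:

    import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier

    val learner = new LightGBMClassifier()
      .setDataTransferMode("streaming")  // LightGBMConstants.StreamingDataTransferMode
      .setSamplingMode("subset")         // sample within the first N rows only
      .setSamplingSubsetSize(500000)     // N = 500000 here; the default is 1000000
    // The deprecated setExecutionMode("streaming") now forwards to dataTransferMode.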
@@ -22,8 +22,19 @@ object SwigUtils extends Serializable {
     */
   def floatArrayToNative(array: Array[Float]): SWIGTYPE_p_float = {
     val colArray = lightgbmlib.new_floatArray(array.length)
-    array.zipWithIndex.foreach(ri =>
-      lightgbmlib.floatArray_setitem(colArray, ri._2.toLong, ri._1))
+    array.zipWithIndex.foreach { case (value, index) => lightgbmlib.floatArray_setitem(colArray, index.toLong, value)}
     colArray
   }
+
+  /** Converts a Java Byte array to a native C++ array using SWIG.
+    *
+    * @param array The Java Byte Array to convert.
+    * @return The SWIG wrapper around the native array.
+    */
+  def byteArrayToNative(array: Array[Byte]): SWIGTYPE_p_unsigned_char = {
+    val colArray = lightgbmlib.new_byteArray(array.length)
+    array.zipWithIndex.foreach { case (value, index) =>
+      lightgbmlib.byteArray_setitem(colArray, index.toLong, value.toShort)}
+    colArray
+  }
 }
@@ -11,7 +11,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
 import org.apache.spark.ml.feature.StringIndexer
 import org.apache.spark.ml.util.MLReadable
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame

 import java.io.File
 import java.nio.file.{Files, Path, Paths}
@@ -85,9 +85,6 @@ abstract class LightGBMClassifierTestData extends Benchmarks
   val transfusionFile: String = "transfusion.csv"

   def baseModel: LightGBMClassifier = {
-    // TODO revert once streaming sparse bug fixed
-    val matrixType = if (executionMode == LightGBMConstants.StreamingExecutionMode) "dense"
-    else "auto"
     new LightGBMClassifier()
       .setFeaturesCol(featuresCol)
       .setRawPredictionCol(rawPredCol)
@@ -98,8 +95,8 @@ abstract class LightGBMClassifierTestData extends Benchmarks
       .setLabelCol(labelCol)
       .setLeafPredictionCol(leafPredCol)
       .setFeaturesShapCol(featuresShapCol)
-      .setExecutionMode(executionMode)
-      .setMatrixType(matrixType)
+      .setDataTransferMode(dataTransferMode)
+      .setMatrixType("auto")
   }

   def assertBinaryImprovement(sdf1: DataFrame, sdf2: DataFrame): Unit = {
@@ -114,6 +114,6 @@ trait LightGBMTestUtils extends TestBase {
   val validationCol = "validation"
   val seed = 42L

-  lazy val executionMode: String = LightGBMConstants.StreamingExecutionMode
-  lazy val executionModeSuffix = ", " + executionMode
+  val dataTransferMode: String = LightGBMConstants.StreamingDataTransferMode
+  val executionModeSuffix = ", " + dataTransferMode
 }
@@ -0,0 +1,57 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.lightgbm.split1
+
+// scalastyle:off magic.number
+/** Tests to validate the functionality of LightGBM module in streaming mode. */
+class VerifyLightGBMClassifierStreamOnly extends LightGBMClassifierTestData {
+  override def ignoreSerializationFuzzing: Boolean = true
+  override def ignoreExperimentFuzzing: Boolean = true
+
+  test("Verify LightGBMClassifier handles global sample mode correctly") {
+    val df = loadBinary(breastCancerFile, "Label")
+    val model = baseModel
+      .setBoostingType("gbdt")
+      .setSamplingMode("global")
+
+    val fitModel = model.fit(df)
+    fitModel.transform(df)
+  }
+
+  test("Verify LightGBMClassifier handles fixed sample mode correctly") {
+    val df = loadBinary(breastCancerFile, "Label")
+    val model = baseModel
+      .setBoostingType("gbdt")
+      .setSamplingMode("fixed")
+
+    val fitModel = model.fit(df)
+    fitModel.transform(df)
+  }
+
+  test("Verify LightGBMClassifier handles subset sample mode correctly") {
+    boostingTypes.foreach { boostingType =>
+      val df = loadBinary(breastCancerFile, "Label")
+      val model = baseModel
+        .setBoostingType("gbdt")
+        .setSamplingMode("subset")
+
+      val fitModel = model.fit(df)
+      fitModel.transform(df)
+    }
+  }
+
+  test("Verify LightGBMClassifier can use cached reference dataset") {
+    val baseClassifier = baseModel
+    assert(baseClassifier.getReferenceDataset.isEmpty)
+
+    val model1 = baseClassifier.fit(pimaDF)
+
+    // Assert the generated reference dataset was saved
+    assert(baseClassifier.getReferenceDataset.nonEmpty)
+
+    // Assert we use the same reference data and get same result
+    val model2 = baseModel.fit(pimaDF)
+    assert(model1.getModel.modelStr == model2.getModel.modelStr)
+  }
+}
@@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.lightgbm.LightGBMConstants

 /** Tests to validate the functionality of LightGBM Ranker module in bulk mode. */
 class VerifyLightGBMRankerBulk extends VerifyLightGBMRankerStream {
-  override lazy val executionMode: String = LightGBMConstants.BulkExecutionMode
+  override val dataTransferMode: String = LightGBMConstants.BulkDataTransferMode
   override def ignoreSerializationFuzzing: Boolean = true
   override def ignoreExperimentFuzzing: Boolean = true
 }
@@ -8,7 +8,7 @@ import com.microsoft.azure.synapse.ml.lightgbm.LightGBMConstants
 /** Tests to validate the functionality of LightGBM module in bulk mode.
   */
 class VerifyLightGBMRegressorBulk extends VerifyLightGBMRegressorStream {
-  override lazy val executionMode: String = LightGBMConstants.BulkExecutionMode
+  override val dataTransferMode: String = LightGBMConstants.BulkDataTransferMode
   override def ignoreSerializationFuzzing: Boolean = true
   override def ignoreExperimentFuzzing: Boolean = true
 }
|
@ -15,24 +15,24 @@ import org.apache.spark.sql.{DataFrame, Row}
|
|||
/** Tests to validate the functionality of LightGBM module in streaming mode.
|
||||
*/
|
||||
class VerifyLightGBMRegressorStream extends LightGBMRegressorTestData {
|
||||
test(verifyLearnerTitleTemplate.format(energyEffFile, executionMode)) {
|
||||
test(verifyLearnerTitleTemplate.format(energyEffFile, dataTransferMode)) {
|
||||
verifyLearnerOnRegressionCsvFile(energyEffFile, "Y1", 0,
|
||||
Some(Seq("X1", "X2", "X3", "X4", "X5", "X6", "X7", "X8", "Y2")))
|
||||
}
|
||||
test(verifyLearnerTitleTemplate.format(airfoilFile, executionMode)) {
|
||||
test(verifyLearnerTitleTemplate.format(airfoilFile, dataTransferMode)) {
|
||||
verifyLearnerOnRegressionCsvFile(airfoilFile, "Scaled sound pressure level", 1)
|
||||
}
|
||||
test(verifyLearnerTitleTemplate.format(tomsHardwareFile, executionMode)) {
|
||||
test(verifyLearnerTitleTemplate.format(tomsHardwareFile, dataTransferMode)) {
|
||||
verifyLearnerOnRegressionCsvFile(tomsHardwareFile, "Mean Number of display (ND)", -4)
|
||||
}
|
||||
test(verifyLearnerTitleTemplate.format(machineFile, executionMode)) {
|
||||
test(verifyLearnerTitleTemplate.format(machineFile, dataTransferMode)) {
|
||||
verifyLearnerOnRegressionCsvFile(machineFile, "ERP", -2)
|
||||
}
|
||||
/* TODO: Spark doesn't seem to like the column names here because of '.', figure out how to read in the data
|
||||
test(verifyLearnerTitleTemplate.format(slumpFile, executionMode)) {
|
||||
verifyLearnerOnRegressionCsvFile(slumpFile, "Compressive Strength (28-day)(Mpa)", 2)
|
||||
} */
|
||||
test(verifyLearnerTitleTemplate.format(concreteFile, executionMode)) {
|
||||
test(verifyLearnerTitleTemplate.format(concreteFile, dataTransferMode)) {
|
||||
verifyLearnerOnRegressionCsvFile(concreteFile, "Concrete compressive strength(MPa, megapascals)", 0)
|
||||
}
|
||||
|
||||
|
|
|
@@ -0,0 +1,80 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.lightgbm.split3
+
+import com.microsoft.azure.synapse.ml.lightgbm.{LightGBMClassificationModel, LightGBMConstants}
+import com.microsoft.azure.synapse.ml.lightgbm.split1._
+import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
+
+// scalastyle:off magic.number
+/** Tests to validate the functionality of LightGBM module in streaming mode. */
+class VerifyLightGBMClassifierStreamBasic extends LightGBMClassifierTestData {
+  /* TODO Figure out why abalone has such poor score
+  test(verifyLearnerTitleTemplate.format(LightGBMConstants.MulticlassObjective, abaloneFile, executionMode)) {
+    verifyLearnerOnMulticlassCsvFile(abaloneFile, "Rings", 2)
+  } */
+  test(verifyLearnerTitleTemplate.format(LightGBMConstants.MulticlassObjective, breastTissueFile, dataTransferMode)) {
+    verifyLearnerOnMulticlassCsvFile(breastTissueFile, "Class", .07)
+  }
+  test(verifyLearnerTitleTemplate.format(LightGBMConstants.MulticlassObjective, carEvaluationFile, dataTransferMode)) {
+    verifyLearnerOnMulticlassCsvFile(carEvaluationFile, "Col7", 2)
+  }
+  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, pimaIndianFile, dataTransferMode)) {
+    verifyLearnerOnBinaryCsvFile(pimaIndianFile, "Diabetes mellitus", 1)
+  }
+  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, banknoteFile, dataTransferMode)) {
+    verifyLearnerOnBinaryCsvFile(banknoteFile, "class", 1)
+  }
+  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, taskFile, dataTransferMode)) {
+    verifyLearnerOnBinaryCsvFile(taskFile, "TaskFailed10", 1)
+  }
+  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, breastCancerFile, dataTransferMode)) {
+    verifyLearnerOnBinaryCsvFile(breastCancerFile, "Label", 1)
+  }
+  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, randomForestFile, dataTransferMode)) {
+    verifyLearnerOnBinaryCsvFile(randomForestFile, "#Malignant", 1)
+  }
+  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, transfusionFile, dataTransferMode)) {
+    verifyLearnerOnBinaryCsvFile(transfusionFile, "Donated", 1)
+  }
+
+  test("Verify LightGBMClassifier save booster to " + pimaIndianFile + executionModeSuffix) {
+    verifySaveBooster(
+      fileName = pimaIndianFile,
+      labelColumnName = "Diabetes mellitus",
+      outputFileName = "model.txt",
+      colsToVerify = Array("Diabetes pedigree function", "Age (years)"))
+  }
+
+  test("Compare benchmark results file to generated file" + executionModeSuffix) {
+    verifyBenchmarks()
+  }
+
+  test("Verify LightGBM Classifier can be run with TrainValidationSplit" + executionModeSuffix) {
+    val model = baseModel.setUseBarrierExecutionMode(true)
+
+    val paramGrid = new ParamGridBuilder()
+      .addGrid(model.numLeaves, Array(5, 10))
+      .addGrid(model.numIterations, Array(10, 20))
+      .addGrid(model.lambdaL1, Array(0.1, 0.5))
+      .addGrid(model.lambdaL2, Array(0.1, 0.5))
+      .build()
+
+    val fitModel = new TrainValidationSplit()
+      .setEstimator(model)
+      .setEvaluator(binaryEvaluator)
+      .setEstimatorParamMaps(paramGrid)
+      .setTrainRatio(0.8)
+      .setParallelism(2)
+      .fit(pimaDF)
+
+    fitModel.transform(pimaDF)
+    assert(fitModel != null)
+
+    // Validate lambda parameters set on model
+    val modelStr = fitModel.bestModel.asInstanceOf[LightGBMClassificationModel].getModel.modelStr.get
+    assert(modelStr.contains("[lambda_l1: 0.1]") || modelStr.contains("[lambda_l1: 0.5]"))
+    assert(modelStr.contains("[lambda_l2: 0.1]") || modelStr.contains("[lambda_l2: 0.5]"))
+  }
+}
@@ -0,0 +1,14 @@
+// Copyright (C) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in project root for information.
+
+package com.microsoft.azure.synapse.ml.lightgbm.split4
+
+import com.microsoft.azure.synapse.ml.lightgbm._
+import com.microsoft.azure.synapse.ml.lightgbm.split3.VerifyLightGBMClassifierStreamBasic
+
+/** Tests to validate the functionality of LightGBM module. */
+class VerifyLightGBMClassifierBulkBasic extends VerifyLightGBMClassifierStreamBasic {
+  override val dataTransferMode: String = LightGBMConstants.BulkDataTransferMode
+  override def ignoreSerializationFuzzing: Boolean = true
+  override def ignoreExperimentFuzzing: Boolean = true
+}
@@ -1,16 +1,16 @@
 // Copyright (C) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License. See LICENSE in project root for information.

-package com.microsoft.azure.synapse.ml.lightgbm.split1
+package com.microsoft.azure.synapse.ml.lightgbm.split5

 import com.microsoft.azure.synapse.ml.core.test.benchmarks.DatasetUtils
 import com.microsoft.azure.synapse.ml.lightgbm._
 import com.microsoft.azure.synapse.ml.lightgbm.dataset.LightGBMDataset
 import com.microsoft.azure.synapse.ml.lightgbm.params.FObjTrait
+import com.microsoft.azure.synapse.ml.lightgbm.split1._
 import org.apache.spark.TaskContext
 import org.apache.spark.ml.feature.{LabeledPoint, VectorAssembler}
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.{DataFrame, Row}
@@ -18,75 +18,11 @@ import org.apache.spark.sql.{DataFrame, Row}
 import scala.math.exp

 // scalastyle:off magic.number

 /** Tests to validate the functionality of LightGBM module in streaming mode. */
 class VerifyLightGBMClassifierStream extends LightGBMClassifierTestData {
-  /* TODO Figure out why abalone has such poor score
-  test(verifyLearnerTitleTemplate.format(LightGBMConstants.MulticlassObjective, abaloneFile, executionMode)) {
-    verifyLearnerOnMulticlassCsvFile(abaloneFile, "Rings", 2)
-  } */
-  test(verifyLearnerTitleTemplate.format(LightGBMConstants.MulticlassObjective, breastTissueFile, executionMode)) {
-    verifyLearnerOnMulticlassCsvFile(breastTissueFile, "Class", .07)
-  }
-  test(verifyLearnerTitleTemplate.format(LightGBMConstants.MulticlassObjective, carEvaluationFile, executionMode)) {
-    verifyLearnerOnMulticlassCsvFile(carEvaluationFile, "Col7", 2)
-  }
-  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, pimaIndianFile, executionMode)) {
-    verifyLearnerOnBinaryCsvFile(pimaIndianFile, "Diabetes mellitus", 1)
-  }
-  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, banknoteFile, executionMode)) {
-    verifyLearnerOnBinaryCsvFile(banknoteFile, "class", 1)
-  }
-  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, taskFile, executionMode)) {
-    verifyLearnerOnBinaryCsvFile(taskFile, "TaskFailed10", 1)
-  }
-  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, breastCancerFile, executionMode)) {
-    verifyLearnerOnBinaryCsvFile(breastCancerFile, "Label", 1)
-  }
-  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, randomForestFile, executionMode)) {
-    verifyLearnerOnBinaryCsvFile(randomForestFile, "#Malignant", 1)
-  }
-  test(verifyLearnerTitleTemplate.format(LightGBMConstants.BinaryObjective, transfusionFile, executionMode)) {
-    verifyLearnerOnBinaryCsvFile(transfusionFile, "Donated", 1)
-  }
-
-  test("Verify LightGBMClassifier save booster to " + pimaIndianFile + executionModeSuffix) {
-    verifySaveBooster(
-      fileName = pimaIndianFile,
-      labelColumnName = "Diabetes mellitus",
-      outputFileName = "model.txt",
-      colsToVerify = Array("Diabetes pedigree function", "Age (years)"))
-  }
-
-  test("Compare benchmark results file to generated file" + executionModeSuffix) {
-    verifyBenchmarks()
-  }
-
-  test("Verify LightGBM Classifier can be run with TrainValidationSplit" + executionModeSuffix) {
-    val model = baseModel.setUseBarrierExecutionMode(true)
-
-    val paramGrid = new ParamGridBuilder()
-      .addGrid(model.numLeaves, Array(5, 10))
-      .addGrid(model.numIterations, Array(10, 20))
-      .addGrid(model.lambdaL1, Array(0.1, 0.5))
-      .addGrid(model.lambdaL2, Array(0.1, 0.5))
-      .build()
-
-    val fitModel = new TrainValidationSplit()
-      .setEstimator(model)
-      .setEvaluator(binaryEvaluator)
-      .setEstimatorParamMaps(paramGrid)
-      .setTrainRatio(0.8)
-      .setParallelism(2)
-      .fit(pimaDF)
-
-    fitModel.transform(pimaDF)
-    assert(fitModel != null)
-
-    // Validate lambda parameters set on model
-    val modelStr = fitModel.bestModel.asInstanceOf[LightGBMClassificationModel].getModel.modelStr.get
-    assert(modelStr.contains("[lambda_l1: 0.1]") || modelStr.contains("[lambda_l1: 0.5]"))
-    assert(modelStr.contains("[lambda_l2: 0.1]") || modelStr.contains("[lambda_l2: 0.5]"))
-  }
+  override def ignoreSerializationFuzzing: Boolean = true
+  override def ignoreExperimentFuzzing: Boolean = true

   test("Verify LightGBM Classifier with batch training" + executionModeSuffix) {
     val batches = Array(0, 2, 10)
@@ -1,13 +1,12 @@
 // Copyright (C) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License. See LICENSE in project root for information.

-package com.microsoft.azure.synapse.ml.lightgbm.split1
+package com.microsoft.azure.synapse.ml.lightgbm.split6

 import com.microsoft.azure.synapse.ml.lightgbm._
+import com.microsoft.azure.synapse.ml.lightgbm.split5.VerifyLightGBMClassifierStream

 /** Tests to validate the functionality of LightGBM module. */
 class VerifyLightGBMClassifierBulk extends VerifyLightGBMClassifierStream {
-  override lazy val executionMode: String = LightGBMConstants.BulkExecutionMode
-  override def ignoreSerializationFuzzing: Boolean = true
-  override def ignoreExperimentFuzzing: Boolean = true
+  override val dataTransferMode: String = LightGBMConstants.BulkDataTransferMode
 }
@@ -684,6 +684,18 @@ jobs:
         lightgbm2:
           PACKAGE: "lightgbm.split2"
           FLAKY: "true"
+        lightgbm3:
+          PACKAGE: "lightgbm.split3"
+          FLAKY: "true"
+        lightgbm4:
+          PACKAGE: "lightgbm.split4"
+          FLAKY: "true"
+        lightgbm5:
+          PACKAGE: "lightgbm.split5"
+          FLAKY: "true"
+        lightgbm6:
+          PACKAGE: "lightgbm.split6"
+          FLAKY: "true"
         opencv:
           PACKAGE: "opencv"
         recommendation: