fix: Improve LGBM exception and logging (#2037)

* Improve LGBM exception and logging

* added log
This commit is contained in:
Scott Votaw 2023-08-02 08:55:27 -07:00 committed by GitHub
Parent db6386c6d6
Commit cde68347a4
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files: 9 additions and 14 deletions

View file

@ -163,22 +163,25 @@ object NetworkManager {
// and a list of partition ids in this executor.
val lightGbmMachineList = driverInput.readLine()
val partitionsByExecutorStr = driverInput.readLine()
val executorPartitionIds: Array[Int] =
parseExecutorPartitionList(partitionsByExecutorStr, taskStatus.executorId)
log.info(s"task $taskId, partition $partitionId received nodes for network init: '$lightGbmMachineList'")
log.info(s"task $taskId, partition $partitionId received partition topology: '$partitionsByExecutorStr'")
log.info(s"task $taskId, partition $partitionId received nodes for network init: '$lightGbmMachineList'")
val executorPartitionIds: Array[Int] =
parseExecutorPartitionList(partitionsByExecutorStr, taskStatus.executorId, log)
NetworkTopologyInfo(lightGbmMachineList, executorPartitionIds, localListenPort)
}.get
}.get
}
/** Extracts the sorted partition ids assigned to the given executor from the
  * driver-provided topology string.
  *
  * @param partitionsByExecutorStr topology string formatted like
  *        "executor1=partition1,partition2:executor2=partition3,partition4"
  * @param executorId id of the executor whose partitions are requested
  * @param log logger used to record the parsed assignment
  * @return this executor's partition ids, sorted in ascending order
  * @throws Exception if executorId does not appear in the topology string
  */
private def parseExecutorPartitionList(partitionsByExecutorStr: String,
                                       executorId: String,
                                       log: Logger): Array[Int] = {
  // Each ':'-separated segment describes one executor's partition assignment.
  val partitionsByExecutor = partitionsByExecutorStr.split(":")
  // Match on "executorId=" so e.g. executor "1" does not match "11=...".
  val executorListStr = partitionsByExecutor.find(line => line.startsWith(executorId + "="))
  if (executorListStr.isEmpty)
    // Interpolate the executor id (not the empty Option) so the error is actionable.
    throw new Exception(s"Could not find partitions for executor $executorId. List: $partitionsByExecutorStr")
  log.info(s"executor $executorId received partitions: '$executorListStr'")
  // Take the comma-separated list after '=' and parse each id as an Int.
  val partitionList = executorListStr.get.split("=")(1)
  partitionList.split(",").map(str => str.toInt).sorted
}

View file

@ -174,15 +174,12 @@ class StreamingPartitionTask extends BasePartitionTask {
val partitionRowCount = ctx.trainingCtx.partitionCounts.get(ctx.partitionId).toInt
val partitionRowOffset = ctx.streamingPartitionOffset
val isSparse = ctx.sharedState.isSparse.get
log.info(s"Inserting rows into training Dataset from partition ${ctx.partitionId}, " +
log.debug(s"Inserting rows into training Dataset from partition ${ctx.partitionId}, " +
s"size $partitionRowCount, offset: $partitionRowOffset, sparse: $isSparse, threadId: ${ctx.threadIndex}")
val dataset = ctx.sharedState.datasetState.streamingDataset.get
val stopIndex = partitionRowOffset + partitionRowCount
insertRowsIntoDataset(ctx, dataset, inputRows, partitionRowOffset, stopIndex, ctx.threadIndex)
log.info(s"Part ${ctx.partitionId}: inserted $partitionRowCount partition ${ctx.partitionId} " +
s"rows into shared training dataset at offset $partitionRowOffset")
}
private def insertRowsIntoDataset(ctx: PartitionTaskContext,
@ -213,9 +210,7 @@ class StreamingPartitionTask extends BasePartitionTask {
if (maxBatchSize == 0) 0
else loadOneDenseMicroBatchBuffer(state, inputRows, 0, maxBatchSize)
if (count > 0) {
log.info(s"Part ${state.ctx.partitionId}: Pushing $count dense rows at $startIndex, will stop at $stopIndex")
if (state.hasInitialScores && state.microBatchSize != count && state.numInitScoreClasses > 1) {
log.info(s"Part ${state.ctx.partitionId}: Adjusting $count initial scores")
(1 until state.numInitScoreClasses).foreach { i =>
(0 until count).foreach { j => {
val score = state.initScoreBuffer.getItem(i * state.microBatchSize + j)
@ -253,7 +248,6 @@ class StreamingPartitionTask extends BasePartitionTask {
if (microBatchRowCount > 0) {
// If we have only a partial micro-batch, and we have multi-class initial scores (i.e. numClass > 1),
// we need to re-coalesce the data since it was stored column-wise based on original microBatchSize
log.info(s"Part ${state.ctx.partitionId}: Pushing $microBatchRowCount sparse rows at $startIndex")
if (state.hasInitialScores && state.microBatchSize != microBatchRowCount && state.numInitScoreClasses > 1) {
(1 until state.numInitScoreClasses).foreach { i => // TODO make this shared
(0 until microBatchRowCount).foreach { j => {
@ -279,8 +273,6 @@ class StreamingPartitionTask extends BasePartitionTask {
// might be more rows, so continue with tail recursion at next index
pushSparseMicroBatches(state, inputRows, startIndex + microBatchRowCount, stopIndex)
} else {
log.info(s"LightGBM pushed $startIndex in partition ${state.ctx.partitionId}")
}
}