diff --git a/core/src/main/scala/org/apache/spark/eventhubs/client/CachedEventHubsReceiver.scala b/core/src/main/scala/org/apache/spark/eventhubs/client/CachedEventHubsReceiver.scala
index 96f9ca25..4a800892 100644
--- a/core/src/main/scala/org/apache/spark/eventhubs/client/CachedEventHubsReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/eventhubs/client/CachedEventHubsReceiver.scala
@@ -38,7 +38,7 @@ import org.apache.spark.util.RpcUtils
 import scala.collection.JavaConverters._
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.duration._
-import scala.concurrent.{ Await, Awaitable, Future }
+import scala.concurrent.{ Await, Awaitable, Future, Promise }
 
 private[spark] trait CachedReceiver {
   private[eventhubs] def receive(ehConf: EventHubsConf,
@@ -130,7 +130,13 @@ private[client] class CachedEventHubsReceiver private (ehConf: EventHubsConf,
   }
 
   private def closeReceiver(): Future[Void] = {
-    retryJava(receiver.close(), "closing a receiver")
+    // Closing a PartitionReceiver is not a retryable operation: after the first call, it always
+    // returns the same CompletableFuture. Therefore, if closing fails with a transient
+    // error, log it and continue.
+    // val dummyResult = Future[Void](null)
+    val dummyResult = Promise[Void]()
+    dummyResult success null
+    retryJava(receiver.close(), "closing a receiver", replaceTransientErrors = dummyResult.future)
   }
 
   private def recreateReceiver(seqNo: SequenceNumber): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/eventhubs/utils/RetryUtils.scala b/core/src/main/scala/org/apache/spark/eventhubs/utils/RetryUtils.scala
index 12ee4217..13b2e532 100644
--- a/core/src/main/scala/org/apache/spark/eventhubs/utils/RetryUtils.scala
+++ b/core/src/main/scala/org/apache/spark/eventhubs/utils/RetryUtils.scala
@@ -79,14 +79,16 @@ private[spark] object RetryUtils extends Logging {
    * @param opName the name of the operation. This is to assist with logging.
    * @param maxRetry The number of times the operation will be retried.
    * @param delay The delay (in milliseconds) before the Future is run again.
+   * @param replaceTransientErrors If not null, a transient error returns this Future instead of retrying.
    * @tparam T the result type from the [[CompletableFuture]]
    * @return the [[Future]] returned by the async operation
    */
   final def retryJava[T](fn: => CompletableFuture[T],
                          opName: String,
                          maxRetry: Int = RetryCount,
-                         delay: Int = 10): Future[T] = {
-    retryScala(toScala(fn), opName, maxRetry, delay)
+                         delay: Int = 10,
+                         replaceTransientErrors: Future[T] = null): Future[T] = {
+    retryScala(toScala(fn), opName, maxRetry, delay, replaceTransientErrors)
   }
 
   /**
@@ -100,13 +102,15 @@ private[spark] object RetryUtils extends Logging {
    * @param opName the name of the operation. This is to assist with logging.
    * @param maxRetry The number of times the operation will be retried.
    * @param delay The delay (in milliseconds) before the Future is run again.
+   * @param replaceTransientErrors If not null, a transient error returns this Future instead of retrying.
    * @tparam T the result type from the [[Future]]
    * @return the [[Future]] returned by the async operation
    */
   final def retryScala[T](fn: => Future[T],
                           opName: String,
                           maxRetry: Int = RetryCount,
-                          delay: Int = 10): Future[T] = {
+                          delay: Int = 10,
+                          replaceTransientErrors: Future[T] = null): Future[T] = {
     def retryHelper(fn: => Future[T], retryCount: Int): Future[T] = {
       val taskId = EventHubsUtils.getTaskId
       fn.recoverWith {
@@ -115,8 +119,13 @@ private[spark] object RetryUtils extends Logging {
             logInfo(s"(TID $taskId) failure: $opName")
             throw eh
           }
-          logInfo(s"(TID $taskId) retrying $opName after $delay ms")
-          after(delay.milliseconds)(retryHelper(fn, retryCount + 1))
+          if (replaceTransientErrors != null) {
+            logInfo(s"(TID $taskId) ignoring transient failure in $opName")
+            replaceTransientErrors
+          } else {
+            logInfo(s"(TID $taskId) retrying $opName after $delay ms")
+            after(delay.milliseconds)(retryHelper(fn, retryCount + 1))
+          }
         case t: Throwable =>
           t.getCause match {
             case eh: EventHubException if eh.getIsTransient =>
@@ -124,8 +133,13 @@ private[spark] object RetryUtils extends Logging {
               logInfo(s"(TID $taskId) failure: $opName")
               throw eh
             }
-            logInfo(s"(TID $taskId) retrying $opName after $delay ms")
-            after(delay.milliseconds)(retryHelper(fn, retryCount + 1))
+            if (replaceTransientErrors != null) {
+              logInfo(s"(TID $taskId) ignoring transient failure in $opName")
+              replaceTransientErrors
+            } else {
+              logInfo(s"(TID $taskId) retrying $opName after $delay ms")
+              after(delay.milliseconds)(retryHelper(fn, retryCount + 1))
+            }
             case _ =>
               logInfo(s"(TID $taskId) failure: $opName")
               throw t
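For reference, a standalone sketch of the fallback pattern this change introduces: on a transient failure, return a caller-supplied, pre-completed Future instead of scheduling another retry. The names TransientError, withFallback, and ReplaceTransientErrorsSketch are illustrative stand-ins only; they are not part of this repository or the Event Hubs client API.

import scala.concurrent.{ Await, ExecutionContext, Future, Promise }
import scala.concurrent.duration._

// Hypothetical stand-in for an EventHubException whose getIsTransient is true.
final class TransientError(msg: String) extends RuntimeException(msg)

object ReplaceTransientErrorsSketch {

  // Same idea as the new replaceTransientErrors parameter: if a fallback Future is
  // supplied, a transient failure is swallowed and the fallback is returned instead
  // of retrying the operation.
  def withFallback[T](fn: => Future[T], fallback: Future[T])(
      implicit ec: ExecutionContext): Future[T] =
    fn.recoverWith {
      case _: TransientError if fallback != null => fallback
    }

  def main(args: Array[String]): Unit = {
    import ExecutionContext.Implicits.global

    // A pre-completed Promise, mirroring what closeReceiver() builds for Future[Void].
    val done = Promise[Unit]()
    done.success(())

    val result =
      withFallback(Future.failed[Unit](new TransientError("link detached")), done.future)

    // Completes normally despite the transient failure.
    Await.result(result, 5.seconds)
    println("transient close failure was ignored")
  }
}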