Add parameters to Aad Auth callback class (#596)
update the aad auth callback class to allow passing a properties bag
This commit is contained in:
Родитель
b41c397038
Коммит
055edd28e3
|
@ -177,6 +177,7 @@ final class EventHubsConf private (private val connectionStr: String)
|
|||
MaxAcceptableBatchReceiveTimeKey,
|
||||
UseAadAuthKey,
|
||||
AadAuthCallbackKey,
|
||||
AadAuthCallbackParamsKey,
|
||||
DynamicPartitionDiscoveryKey
|
||||
).map(_.toLowerCase).toSet
|
||||
|
||||
|
@ -619,9 +620,39 @@ final class EventHubsConf private (private val connectionStr: String)
|
|||
}
|
||||
|
||||
def aadAuthCallback(): Option[AadAuthenticationCallback] = {
|
||||
self.get(AadAuthCallbackKey) map (className => {
|
||||
Class.forName(className).newInstance().asInstanceOf[AadAuthenticationCallback]
|
||||
})
|
||||
val params: Map[String, Object] = self.get(AadAuthCallbackParamsKey) map EventHubsConf
|
||||
.read[Map[String, Object]] getOrElse Map.empty
|
||||
if (params.isEmpty) {
|
||||
self.get(AadAuthCallbackKey) map (className => {
|
||||
Class
|
||||
.forName(className)
|
||||
.getConstructor()
|
||||
.newInstance()
|
||||
.asInstanceOf[AadAuthenticationCallback]
|
||||
})
|
||||
} else {
|
||||
self.get(AadAuthCallbackKey) map (className => {
|
||||
Class
|
||||
.forName(className)
|
||||
.getConstructor(classOf[Map[String, Object]])
|
||||
.newInstance(params)
|
||||
.asInstanceOf[AadAuthenticationCallback]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* set the parameter passed to the aad auth callback class in case it accepts parameters. The aad auth class should
|
||||
* either accepts no parameters or accepts a properties bag(Map[String, Object]). This optional parameter can be used
|
||||
* to pass authentication secrets to the aad auth callback class securely.
|
||||
* More info about this: https://docs.microsoft.com/en-us/azure/event-hubs/authorize-access-azure-active-directory
|
||||
*
|
||||
* @param params The parameters passed to the aad authentication callback class. The parameters should be passed as a
|
||||
* sequence of Strings.
|
||||
* @return the updated [[EventHubsConf]] instance
|
||||
*/
|
||||
def setAadAuthCallbackParams(params: Map[String, Object]): EventHubsConf = {
|
||||
set(AadAuthCallbackParamsKey, EventHubsConf.write[Map[String, Object]](params))
|
||||
}
|
||||
|
||||
// The simulated client (and simulated eventhubs) will be used. These
|
||||
|
@ -683,6 +714,7 @@ object EventHubsConf extends Logging {
|
|||
val MaxAcceptableBatchReceiveTimeKey = "eventhubs.maxAcceptableBatchReceiveTime"
|
||||
val UseAadAuthKey = "eventhubs.useAadAuth"
|
||||
val AadAuthCallbackKey = "eventhubs.aadAuthCallback"
|
||||
val AadAuthCallbackParamsKey = "eventhubs.AadAuthCallbackParams"
|
||||
val DynamicPartitionDiscoveryKey = "eventhubs.DynamicPartitionDiscovery"
|
||||
|
||||
/** Creates an EventHubsConf */
|
||||
|
|
|
@ -62,7 +62,7 @@ private class ClientConnectionPool(val ehConf: EventHubsConf) extends Logging {
|
|||
while (client == null) {
|
||||
if (ehConf.useAadAuth) {
|
||||
val ehClientOption: EventHubClientOptions = new EventHubClientOptions()
|
||||
.setMaximumSilentTime(ehConf.maxSilentTime.getOrElse(DefaultMaxSilentTime))
|
||||
.setMaximumSilentTime(ehConf.maxSilentTime.getOrElse(MinSilentTime))
|
||||
.setOperationTimeout(ehConf.receiverTimeout.getOrElse(DefaultReceiverTimeout))
|
||||
.setRetryPolicy(RetryPolicy.getDefault)
|
||||
client = Await.result(
|
||||
|
|
|
@ -17,15 +17,19 @@
|
|||
|
||||
package org.apache.spark.eventhubs
|
||||
|
||||
import org.apache.spark.eventhubs.utils.{
|
||||
AadAuthenticationCallbackMock,
|
||||
AadAuthenticationCallbackMockWithParams,
|
||||
EventHubsTestUtils,
|
||||
MetricPluginMock,
|
||||
ThrottlingStatusPluginMock
|
||||
}
|
||||
import java.time.Duration
|
||||
import java.util.NoSuchElementException
|
||||
|
||||
import org.apache.spark.eventhubs.utils.{AadAuthenticationCallbackMock, EventHubsTestUtils, MetricPluginMock, ThrottlingStatusPluginMock}
|
||||
import org.json4s.NoTypeHints
|
||||
import org.json4s.jackson.Serialization
|
||||
import org.json4s.jackson.Serialization.{read => sread}
|
||||
import org.json4s.jackson.Serialization.{write => swrite}
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.json4s.jackson.Serialization.{ read => sread }
|
||||
import org.json4s.jackson.Serialization.{ write => swrite }
|
||||
import org.scalatest.{ BeforeAndAfterAll, FunSuite }
|
||||
|
||||
/**
|
||||
* Tests [[EventHubsConf]] for correctness.
|
||||
|
@ -377,14 +381,28 @@ class EventHubsConfSuite extends FunSuite with BeforeAndAfterAll {
|
|||
assert(expectedTime == actualTime)
|
||||
}
|
||||
|
||||
|
||||
test("validate - AadAuthenticationCallback") {
|
||||
val aadAuthCallback = new AadAuthenticationCallbackMock()
|
||||
val eventHubConfig = testUtils.getEventHubsConf()
|
||||
val eventHubConfig = testUtils
|
||||
.getEventHubsConf()
|
||||
.setAadAuthCallback(aadAuthCallback)
|
||||
|
||||
val actualCallback = eventHubConfig.aadAuthCallback()
|
||||
assert(eventHubConfig.useAadAuth)
|
||||
assert(actualCallback.get.isInstanceOf[AadAuthenticationCallbackMock])
|
||||
}
|
||||
|
||||
test("validate - AadAuthenticationCallbackWithParams") {
|
||||
val params: Map[String, String] = Map("authority" -> "passed-tenant-id")
|
||||
val aadAuthCallbackWithParams = new AadAuthenticationCallbackMockWithParams(params)
|
||||
val eventHubConfig = testUtils
|
||||
.getEventHubsConf()
|
||||
.setAadAuthCallback(aadAuthCallbackWithParams)
|
||||
.setAadAuthCallbackParams(params)
|
||||
|
||||
val actualCallback = eventHubConfig.aadAuthCallback()
|
||||
assert(eventHubConfig.useAadAuth)
|
||||
assert(actualCallback.get.isInstanceOf[AadAuthenticationCallbackMockWithParams])
|
||||
assert(actualCallback.get.authority == "passed-tenant-id")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,3 +8,12 @@ class AadAuthenticationCallbackMock extends AadAuthenticationCallback {
|
|||
|
||||
override def authority: String = "Fake-tenant-id"
|
||||
}
|
||||
|
||||
class AadAuthenticationCallbackMockWithParams(params: Map[String, Object])
|
||||
extends AadAuthenticationCallback {
|
||||
override def acquireToken(s: String, s1: String, o: Any): CompletableFuture[String] = {
|
||||
new CompletableFuture[String]()
|
||||
}
|
||||
|
||||
override def authority: String = params("authority").asInstanceOf[String]
|
||||
}
|
||||
|
|
|
@ -2,16 +2,22 @@
|
|||
This guide will show you how you can
|
||||
<a href="https://docs.microsoft.com/en-us/azure/event-hubs/authenticate-application" target="_blank">use AAD authentication to access Eventhubs</a>.
|
||||
|
||||
* [Use Service Principal + Secret to authorize](#use-service-principal-+-secret-to-authorize)
|
||||
* [Use Service Principal + Certificate to authorize](#use-service-principal-+-certificate-to-authorize)
|
||||
* [Use Service Principal with Secret to Authorize](#use-service-principal-with-secret-to-authorize)
|
||||
* [Write Secrets in Callback Class](#write-secret-in-callback-class)
|
||||
* [Pass Secrets to Callback Class](#pass-secret-to-callback-class)
|
||||
* [Use Service Principal with Certificate to Authorize](#use-service-principal-with-certificate-to-authorize)
|
||||
|
||||
|
||||
## Use Service Principal + Secret to authorize
|
||||
First, you need to create a callback class extends from `org.apache.spark.eventhubs.utils.AadAuthenticationCallback`,
|
||||
## Use Service Principal with Secret to Authorize
|
||||
First, you need to create a callback class extends from `org.apache.spark.eventhubs.utils.AadAuthenticationCallback`. There are two options on how the callback class can access the secrets. Either set the secrets directly in the class definition, or pass the secrets in a properties bag of type `Map[String, Object]` to the callback class.
|
||||
Please note that since the connector is using reflection to instantiate the callback class on each executor node, the callback class definition should be packaged in a jar file and be added to your cluster.
|
||||
|
||||
### Write Secret in Callback Class
|
||||
In this case, you set the required secrets in the callback class as shown in the below example:
|
||||
|
||||
```scala
|
||||
import java.util.Collections
|
||||
import java.util.concurrent.CompletableFuture
|
||||
|
||||
import com.microsoft.aad.msal4j.{IAuthenticationResult, _}
|
||||
import org.apache.spark.eventhubs.utils.AadAuthenticationCallback
|
||||
|
||||
|
@ -20,14 +26,14 @@ case class AuthBySecretCallBack() extends AadAuthenticationCallback{
|
|||
implicit def toJavaFunction[A, B](f: Function1[A, B]) = new java.util.function.Function[A, B] {
|
||||
override def apply(a: A): B = f(a)
|
||||
}
|
||||
|
||||
override def authority: String = "your-tenant-id"
|
||||
|
||||
val clientId: String = "your-client-id"
|
||||
val clientSecret: String = "your-client-secret"
|
||||
|
||||
override def acquireToken(audience: String, authority: String, state: Any): CompletableFuture[String] = try {
|
||||
var app = ConfidentialClientApplication
|
||||
.builder("clientId", ClientCredentialFactory.createFromSecret(this.clientSecret))
|
||||
.builder(clientId, ClientCredentialFactory.createFromSecret(this.clientSecret))
|
||||
.authority("https://login.microsoftonline.com/" + authority)
|
||||
.build
|
||||
|
||||
|
@ -42,18 +48,77 @@ case class AuthBySecretCallBack() extends AadAuthenticationCallback{
|
|||
}
|
||||
}
|
||||
```
|
||||
and then set the authentication to use AAD auth.
|
||||
|
||||
Now you can use the `setAadAuthCallback` option in `EventHubsConf` to us AAD authentication to connect to your EventHub instance.
|
||||
|
||||
```scala
|
||||
val connectionString = ConnectionStringBuilder()
|
||||
.setAadAuthConnectionString(new URI("your-ehs-endpoint"), "your-ehs-name")
|
||||
.build
|
||||
|
||||
val ehConf = EventHubsConf(connectionString)
|
||||
.setConsumerGroup("consumerGroup")
|
||||
.setAadAuthCallback(AuthBySecretCallBack())
|
||||
```
|
||||
|
||||
|
||||
## Use Service Principal + Certificate to authorize
|
||||
### Pass Secrets to Callback Class
|
||||
Another option is to pass the secrets in a properties bag to the callback class. For instance, if you want to read the secrets from a [secret scope](https://docs.microsoft.com/en-us/azure/databricks/security/secrets/secret-scopes), you can use `dbutils` API to get the secrets on the driver and pass those to the callback class. Note that the callback class only accepts one parameter of type `Map[String, Object]`.
|
||||
Here is an example showing how you can do so:
|
||||
|
||||
```scala
|
||||
import java.util.Collections
|
||||
import java.util.concurrent.CompletableFuture
|
||||
import com.microsoft.aad.msal4j.{IAuthenticationResult, _}
|
||||
import org.apache.spark.eventhubs.utils.AadAuthenticationCallback
|
||||
|
||||
class AuthBySecretCallBackWithParams(params: Map[String, Object]) extends AadAuthenticationCallback{
|
||||
|
||||
implicit def toJavaFunction[A, B](f: Function1[A, B]) = new java.util.function.Function[A, B] {
|
||||
override def apply(a: A): B = f(a)
|
||||
}
|
||||
|
||||
override def authority: String = params("authority").asInstanceOf[String]
|
||||
val clientId: String = params("clientId").asInstanceOf[String]
|
||||
val clientSecret: String = params("clientSecret").asInstanceOf[String]
|
||||
|
||||
override def acquireToken(audience: String, authority: String, state: Any): CompletableFuture[String] = try {
|
||||
var app = ConfidentialClientApplication
|
||||
.builder(clientId, ClientCredentialFactory.createFromSecret(this.clientSecret))
|
||||
.authority("https://login.microsoftonline.com/" + authority)
|
||||
.build
|
||||
|
||||
val parameters = ClientCredentialParameters.builder(Collections.singleton(audience + ".default")).build
|
||||
|
||||
app.acquireToken(parameters).thenApply((result: IAuthenticationResult) => result.accessToken())
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
val failed = new CompletableFuture[String]
|
||||
failed.completeExceptionally(e)
|
||||
failed
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
In this case you should use both `setAadAuthCallback` and `setAadAuthCallbackParams` options in `EventHubsConf` to us AAD authentication to connect to your EventHub instance.
|
||||
|
||||
```scala
|
||||
val params: Map[String, String] = Map("authority" -> dbutils.secrets.get(scope = "nykvsecrets", key = "ehaadtesttenantid"),
|
||||
"clientId" -> dbutils.secrets.get(scope = "nykvsecrets", key = "ehaadtestclientid"),
|
||||
"clientSecret" -> dbutils.secrets.get(scope = "nykvsecrets", key = "ehaadtestclientsecret"))
|
||||
|
||||
val connectionString = ConnectionStringBuilder()
|
||||
.setAadAuthConnectionString(new URI("your-ehs-endpoint"), "your-ehs-name")
|
||||
.build
|
||||
|
||||
val ehConf = EventHubsConf(connectionString)
|
||||
.setConsumerGroup("consumerGroup")
|
||||
.setAadAuthCallback(new AuthBySecretCallBackWithParams(params))
|
||||
.setAadAuthCallbackParams(params)
|
||||
```
|
||||
|
||||
|
||||
## Use Service Principal with Certificate to Authorize
|
||||
|
||||
Alternatively, you can use certificate to make your connections.
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче