diff --git a/data/src/main/java/com/microsoft/azure/kusto/data/AadAuthenticationHelper.java b/data/src/main/java/com/microsoft/azure/kusto/data/AadAuthenticationHelper.java index 36fd590f..f79fa1f3 100644 --- a/data/src/main/java/com/microsoft/azure/kusto/data/AadAuthenticationHelper.java +++ b/data/src/main/java/com/microsoft/azure/kusto/data/AadAuthenticationHelper.java @@ -39,13 +39,15 @@ class AadAuthenticationHelper { private AuthenticationResult lastAuthenticationResult; private Lock lastAuthenticationResultLock = new ReentrantLock(); private String applicationClientId; + private Callable tokenProvider; private enum AuthenticationType { AAD_USERNAME_PASSWORD, AAD_APPLICATION_KEY, AAD_DEVICE_LOGIN, AAD_APPLICATION_CERTIFICATE, - AAD_ACCESS_TOKEN + AAD_ACCESS_TOKEN, + AAD_ACCESS_TOKEN_PROVIDER; } AadAuthenticationHelper(@NotNull ConnectionStringBuilder csb) throws URISyntaxException { @@ -67,6 +69,9 @@ class AadAuthenticationHelper { } else if (StringUtils.isNotBlank(csb.getAccessToken())) { authenticationType = AuthenticationType.AAD_ACCESS_TOKEN; accessToken = csb.getAccessToken(); + } else if(csb.getTokenProvider() != null){ + authenticationType = AuthenticationType.AAD_ACCESS_TOKEN_PROVIDER; + tokenProvider = csb.getTokenProvider(); } else { authenticationType = AuthenticationType.AAD_DEVICE_LOGIN; } @@ -81,11 +86,19 @@ class AadAuthenticationHelper { } } - String acquireAccessToken() throws DataServiceException { + String acquireAccessToken() throws DataServiceException, DataClientException { if (authenticationType == AuthenticationType.AAD_ACCESS_TOKEN) { return accessToken; } + if (authenticationType == AuthenticationType.AAD_ACCESS_TOKEN_PROVIDER) { + try { + return tokenProvider.call(); + } catch (Exception e) { + throw new DataClientException(clusterUrl, e.getMessage(), e); + } + } + if (lastAuthenticationResult == null) { acquireToken(); } else if (IsInvalidToken()) { diff --git a/data/src/main/java/com/microsoft/azure/kusto/data/ClientImpl.java b/data/src/main/java/com/microsoft/azure/kusto/data/ClientImpl.java index c93cf5c8..be8bd339 100644 --- a/data/src/main/java/com/microsoft/azure/kusto/data/ClientImpl.java +++ b/data/src/main/java/com/microsoft/azure/kusto/data/ClientImpl.java @@ -139,7 +139,7 @@ public class ClientImpl implements Client, StreamingClient { return Utils.post(clusterEndpoint, null, stream, timeoutMs.intValue() + CLIENT_SERVER_DELTA_IN_MILLISECS, headers, leaveOpen); } - private HashMap initHeaders() throws DataServiceException { + private HashMap initHeaders() throws DataServiceException, DataClientException { HashMap headers = new HashMap<>(); headers.put("x-ms-client-version", clientVersionForTracing); if (applicationNameForTracing != null) { diff --git a/data/src/main/java/com/microsoft/azure/kusto/data/ConnectionStringBuilder.java b/data/src/main/java/com/microsoft/azure/kusto/data/ConnectionStringBuilder.java index d0b2b179..ffc0873b 100644 --- a/data/src/main/java/com/microsoft/azure/kusto/data/ConnectionStringBuilder.java +++ b/data/src/main/java/com/microsoft/azure/kusto/data/ConnectionStringBuilder.java @@ -7,6 +7,7 @@ import org.apache.commons.lang3.StringUtils; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.concurrent.Callable; public class ConnectionStringBuilder { @@ -21,6 +22,7 @@ public class ConnectionStringBuilder { private String clientVersionForTracing; private String applicationNameForTracing; private String accessToken; + private Callable tokenProvider; String getClusterUrl() { return clusterUri; @@ -74,6 +76,10 @@ public class ConnectionStringBuilder { return accessToken; } + public Callable getTokenProvider() { + return tokenProvider; + } + private ConnectionStringBuilder(String resourceUri) { clusterUri = resourceUri; username = null; @@ -84,6 +90,7 @@ public class ConnectionStringBuilder { x509Certificate = null; privateKey = null; accessToken = null; + tokenProvider = null; } public static ConnectionStringBuilder createWithAadUserCredentials(String resourceUri, @@ -183,4 +190,19 @@ public class ConnectionStringBuilder { csb.accessToken = token; return csb; } + + public static ConnectionStringBuilder createWithAadTokenProviderAuthentication(String resourceUri, Callable tokenProviderCallable) { + if (StringUtils.isEmpty(resourceUri)) { + throw new IllegalArgumentException("resourceUri cannot be null or empty"); + } + + if (tokenProviderCallable == null) { + throw new IllegalArgumentException("tokenProviderCallback cannot be null"); + } + + ConnectionStringBuilder csb = new ConnectionStringBuilder(resourceUri); + csb.tokenProvider = tokenProviderCallable; + return csb; + } + } \ No newline at end of file diff --git a/data/src/test/java/com/microsoft/azure/kusto/data/AadAuthenticationHelperTest.java b/data/src/test/java/com/microsoft/azure/kusto/data/AadAuthenticationHelperTest.java index 097f2527..96ee64c8 100644 --- a/data/src/test/java/com/microsoft/azure/kusto/data/AadAuthenticationHelperTest.java +++ b/data/src/test/java/com/microsoft/azure/kusto/data/AadAuthenticationHelperTest.java @@ -5,6 +5,7 @@ package com.microsoft.azure.kusto.data; import com.microsoft.aad.adal4j.AuthenticationResult; import com.microsoft.aad.adal4j.UserInfo; +import com.microsoft.azure.kusto.data.exceptions.DataClientException; import com.microsoft.azure.kusto.data.exceptions.DataServiceException; import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; import org.bouncycastle.cert.X509CertificateHolder; @@ -97,7 +98,7 @@ public class AadAuthenticationHelperTest { @Test @DisplayName("validate cached token. Refresh if needed. Call regularly if no refresh token") - void useCachedTokenAndRefreshWhenNeeded() throws InterruptedException, ExecutionException, ServiceUnavailableException, IOException, DataServiceException, URISyntaxException, CertificateException, OperatorCreationException, PKCSException { + void useCachedTokenAndRefreshWhenNeeded() throws InterruptedException, ExecutionException, ServiceUnavailableException, IOException, DataServiceException, URISyntaxException, CertificateException, OperatorCreationException, PKCSException, DataClientException { String certFilePath = Paths.get("src", "test", "resources", "cert.cer").toString(); String privateKeyPath = Paths.get("src", "test", "resources", "key.pem").toString(); diff --git a/data/src/test/java/com/microsoft/azure/kusto/data/ConnectionStringBuilderTest.java b/data/src/test/java/com/microsoft/azure/kusto/data/ConnectionStringBuilderTest.java index 7c8f92f1..3c431b1c 100644 --- a/data/src/test/java/com/microsoft/azure/kusto/data/ConnectionStringBuilderTest.java +++ b/data/src/test/java/com/microsoft/azure/kusto/data/ConnectionStringBuilderTest.java @@ -141,4 +141,21 @@ public class ConnectionStringBuilderTest { Assertions.assertDoesNotThrow( () -> ConnectionStringBuilder .createWithAadAccessTokenAuthentication("resource.uri","token")); } + + @Test + @DisplayName("validate createWithAadTokenProviderAuthentication throws IllegalArgumentException exception when missing or invalid parameters") + void createWithAadTokenProviderAuthentication(){ + + Assertions.assertThrows(IllegalArgumentException.class, + () -> ConnectionStringBuilder + .createWithAadTokenProviderAuthentication(null, () -> "token")); + Assertions.assertThrows(IllegalArgumentException.class, + () -> ConnectionStringBuilder + .createWithAadTokenProviderAuthentication("", () -> "token")); + Assertions.assertThrows(IllegalArgumentException.class, + () -> ConnectionStringBuilder + .createWithAadTokenProviderAuthentication("resource.uri", null)); + Assertions.assertDoesNotThrow( () -> ConnectionStringBuilder + .createWithAadTokenProviderAuthentication("resource.uri", () -> "token")); + } }