diff --git a/aad/jwt.go b/aad/jwt.go index 6097a61..0ac7afd 100644 --- a/aad/jwt.go +++ b/aad/jwt.go @@ -21,12 +21,32 @@ const ( ) type ( + tokenProviderConfiguration struct { + tenantID string + clientID string + clientSecret string + certificatePath string + certificatePassword string + env *azure.Environment + } + // TokenProvider provides cbs.TokenProvider functionality for Azure Active Directory JWT tokens TokenProvider struct { tokenProvider *adal.ServicePrincipalToken } + + // JwtTokenProviderOption provides configuration options for constructing AAD Token Providers + JwtTokenProviderOption func(provider *tokenProviderConfiguration) error ) +// JwtTokenProviderWithEnvironment configures the token provider to use a specific Azure Environment +func JwtTokenProviderWithEnvironment(env *azure.Environment) JwtTokenProviderOption { + return func(config *tokenProviderConfiguration) error { + config.env = env + return nil + } +} + // NewProvider builds an Azure Active Directory claims-based security token provider func NewProvider(tokenProvider *adal.ServicePrincipalToken) auth.TokenProvider { return &TokenProvider{ @@ -43,62 +63,78 @@ func NewProvider(tokenProvider *adal.ServicePrincipalToken) auth.TokenProvider { // "AZURE_CERTIFICATE_PATH" and "AZURE_CERTIFICATE_PASSWORD" // // 3. Managed Service Identity (MSI): attempt to authenticate via MSI -func NewProviderFromEnvironment() (auth.TokenProvider, error) { - tenantID := os.Getenv("AZURE_TENANT_ID") - clientID := os.Getenv("AZURE_CLIENT_ID") - clientSecret := os.Getenv("AZURE_CLIENT_SECRET") - certificatePath := os.Getenv("AZURE_CERTIFICATE_PATH") - certificatePassword := os.Getenv("AZURE_CERTIFICATE_PASSWORD") - envName := os.Getenv("AZURE_ENVIRONMENT") +// +// +// The Azure Environment used can be specified using the name of the Azure Environment set in "AZURE_ENVIRONMENT" var. +func NewProviderFromEnvironment(opts ...JwtTokenProviderOption) (auth.TokenProvider, error) { + config := &tokenProviderConfiguration{ + tenantID: os.Getenv("AZURE_TENANT_ID"), + clientID: os.Getenv("AZURE_CLIENT_ID"), + clientSecret: os.Getenv("AZURE_CLIENT_SECRET"), + certificatePath: os.Getenv("AZURE_CERTIFICATE_PATH"), + certificatePassword: os.Getenv("AZURE_CERTIFICATE_PASSWORD"), + } - var env azure.Environment - if envName == "" { - env = azure.PublicCloud - } else { - var err error - env, err = azure.EnvironmentFromName(envName) + for _, opt := range opts { + err := opt(config) if err != nil { return nil, err } } - oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, tenantID) + if config.env == nil { + env, err := azureEnvFromEnvironment() + if err != nil { + return nil, err + } + config.env = env + } + + spToken, err := config.newServicePrincipalToken() + if err != nil { + return nil, err + } + return NewProvider(spToken), nil +} + +func (c *tokenProviderConfiguration) newServicePrincipalToken() (*adal.ServicePrincipalToken, error) { + oauthConfig, err := adal.NewOAuthConfig(c.env.ActiveDirectoryEndpoint, c.tenantID) if err != nil { return nil, err } // 1.Client Credentials - if clientSecret != "" { + if c.clientSecret != "" { log.Debug("creating a token via a service principal client secret") - spToken, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, resource) + spToken, err := adal.NewServicePrincipalToken(*oauthConfig, c.clientID, c.clientSecret, resource) if err != nil { return nil, fmt.Errorf("failed to get oauth token from client credentials: %v", err) } if err := spToken.Refresh(); err != nil { return nil, fmt.Errorf("failed to refersh token: %v", spToken) } - return NewProvider(spToken), nil + return spToken, nil } // 2. Client Certificate - if certificatePath != "" { + if c.certificatePath != "" { log.Debug("creating a token via a service principal client certificate") - certData, err := ioutil.ReadFile(certificatePath) + certData, err := ioutil.ReadFile(c.certificatePath) if err != nil { - return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err) + return nil, fmt.Errorf("failed to read the certificate file (%s): %v", c.certificatePath, err) } - certificate, rsaPrivateKey, err := decodePkcs12(certData, certificatePassword) + certificate, rsaPrivateKey, err := decodePkcs12(certData, c.certificatePassword) if err != nil { return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err) } - spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, certificate, rsaPrivateKey, resource) + spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, c.clientID, certificate, rsaPrivateKey, resource) if err != nil { return nil, fmt.Errorf("failed to get oauth token from certificate auth: %v", err) } if err := spToken.Refresh(); err != nil { return nil, fmt.Errorf("failed to refersh token: %v", spToken) } - return NewProvider(spToken), nil + return spToken, nil } // 3. By default return MSI @@ -114,7 +150,7 @@ func NewProviderFromEnvironment() (auth.TokenProvider, error) { if err := spToken.Refresh(); err != nil { return nil, fmt.Errorf("failed to refersh token: %v", spToken) } - return NewProvider(spToken), nil + return spToken, nil } // GetToken gets a CBS JWT token @@ -152,3 +188,19 @@ func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.Private return certificate, rsaPrivateKey, nil } + +func azureEnvFromEnvironment() (*azure.Environment, error) { + envName := os.Getenv("AZURE_ENVIRONMENT") + + var env azure.Environment + if envName == "" { + env = azure.PublicCloud + } else { + var err error + env, err = azure.EnvironmentFromName(envName) + if err != nil { + return nil, err + } + } + return &env, nil +} diff --git a/hub.go b/hub.go index 4f4bbe8..37c114e 100644 --- a/hub.go +++ b/hub.go @@ -233,6 +233,16 @@ func HubWithUserAgent(userAgent string) HubOption { } } +// HubWithEnvironment configures the hub to use the specified environment. +// +// By default, the hub instance will use Azure US Public cloud environment +func HubWithEnvironment(env azure.Environment) HubOption { + return func(h *hub) error { + h.namespace.environment = env + return nil + } +} + func (h *hub) appendAgent(userAgent string) error { ua := path.Join(h.userAgent, userAgent) if len(ua) > maxUserAgentLen { diff --git a/hub_test.go b/hub_test.go index 8c9efcc..cda62df 100644 --- a/hub_test.go +++ b/hub_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/Azure/azure-event-hubs-go/aad" + "github.com/Azure/azure-event-hubs-go/auth" "github.com/Azure/azure-event-hubs-go/sas" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -18,7 +19,7 @@ import ( func (suite *eventHubSuite) TestSasToken() { tests := map[string]func(*testing.T, Client, []string, string){ - "TestMultiSendAndReceive": testMultiSendAndReceive, + //"TestMultiSendAndReceive": testMultiSendAndReceive, "TestHubRuntimeInformation": testHubRuntimeInformation, "TestHubPartitionRuntimeInformation": testHubPartitionRuntimeInformation, } @@ -35,11 +36,7 @@ func (suite *eventHubSuite) TestSasToken() { if err != nil { t.Fatal(err) } - client, err := NewClient(suite.namespace, hubName, provider) - if err != nil { - t.Fatal(err) - } - + client := suite.newClientWithProvider(t, hubName, provider) testFunc(t, client, *mgmtHub.PartitionIds, hubName) if err := client.Close(); err != nil { t.Fatal(err) @@ -65,14 +62,7 @@ func (suite *eventHubSuite) TestPartitionedSender() { } defer suite.deleteEventHub(context.Background(), hubName) partitionID := (*mgmtHub.PartitionIds)[0] - provider, err := aad.NewProviderFromEnvironment() - if err != nil { - t.Fatal(err) - } - client, err := NewClient(suite.namespace, hubName, provider, HubWithPartitionedSender(partitionID)) - if err != nil { - t.Fatal(err) - } + client := suite.newClient(t, hubName, HubWithPartitionedSender(partitionID)) testFunc(t, client, partitionID) if err := client.Close(); err != nil { @@ -136,15 +126,7 @@ func (suite *eventHubSuite) TestMultiPartition() { t.Fatal(err) } defer suite.deleteEventHub(context.Background(), hubName) - provider, err := aad.NewProviderFromEnvironment() - if err != nil { - t.Fatal(err) - } - client, err := NewClient(suite.namespace, hubName, provider) - if err != nil { - t.Fatal(err) - } - + client := suite.newClient(t, hubName) testFunc(t, client, *mgmtHub.PartitionIds, hubName) if err := client.Close(); err != nil { t.Fatal(err) @@ -202,16 +184,8 @@ func (suite *eventHubSuite) TestHubManagement() { t.Fatal(err) } defer suite.deleteEventHub(context.Background(), hubName) - provider, err := aad.NewProviderFromEnvironment() - if err != nil { - t.Fatal(err) - } - client, err := NewClient(suite.namespace, hubName, provider) - if err != nil { - t.Fatal(err) - } - - testFunc(t, client, *mgmtHub.PartitionIds, *mgmtHub.Name) + client := suite.newClient(t, hubName) + testFunc(t, client, *mgmtHub.PartitionIds, hubName) if err := client.Close(); err != nil { t.Fatal(err) } @@ -226,6 +200,7 @@ func testHubRuntimeInformation(t *testing.T, client Client, partitionIDs []strin if err != nil { t.Fatal(err) } + log.Debug(info.PartitionIDs) assert.Equal(t, len(partitionIDs), info.PartitionCount) assert.Equal(t, hubName, info.Path) } @@ -304,3 +279,20 @@ func BenchmarkReceive(b *testing.B) { wg.Wait() b.StopTimer() } + +func (suite *eventHubSuite) newClient(t *testing.T, hubName string, opts ...HubOption) Client { + provider, err := aad.NewProviderFromEnvironment(aad.JwtTokenProviderWithEnvironment(&suite.env)) + if err != nil { + t.Fatal(err) + } + return suite.newClientWithProvider(t, hubName, provider, opts...) +} + +func (suite *eventHubSuite) newClientWithProvider(t *testing.T, hubName string, provider auth.TokenProvider, opts ...HubOption) Client { + opts = append(opts, HubWithEnvironment(suite.env)) + client, err := NewClient(suite.namespace, hubName, provider, opts...) + if err != nil { + t.Fatal(err) + } + return client +}