pass environment down from hub options

This commit is contained in:
David Justice 2018-02-22 12:52:55 -08:00
Родитель 4b4cd3f5f0
Коммит b4be1d2439
3 изменённых файлов: 111 добавлений и 57 удалений

Просмотреть файл

@ -21,12 +21,32 @@ const (
)
type (
tokenProviderConfiguration struct {
tenantID string
clientID string
clientSecret string
certificatePath string
certificatePassword string
env *azure.Environment
}
// TokenProvider provides cbs.TokenProvider functionality for Azure Active Directory JWT tokens
TokenProvider struct {
tokenProvider *adal.ServicePrincipalToken
}
// JwtTokenProviderOption provides configuration options for constructing AAD Token Providers
JwtTokenProviderOption func(provider *tokenProviderConfiguration) error
)
// JwtTokenProviderWithEnvironment configures the token provider to use a specific Azure Environment
func JwtTokenProviderWithEnvironment(env *azure.Environment) JwtTokenProviderOption {
return func(config *tokenProviderConfiguration) error {
config.env = env
return nil
}
}
// NewProvider builds an Azure Active Directory claims-based security token provider
func NewProvider(tokenProvider *adal.ServicePrincipalToken) auth.TokenProvider {
return &TokenProvider{
@ -43,62 +63,78 @@ func NewProvider(tokenProvider *adal.ServicePrincipalToken) auth.TokenProvider {
// "AZURE_CERTIFICATE_PATH" and "AZURE_CERTIFICATE_PASSWORD"
//
// 3. Managed Service Identity (MSI): attempt to authenticate via MSI
func NewProviderFromEnvironment() (auth.TokenProvider, error) {
tenantID := os.Getenv("AZURE_TENANT_ID")
clientID := os.Getenv("AZURE_CLIENT_ID")
clientSecret := os.Getenv("AZURE_CLIENT_SECRET")
certificatePath := os.Getenv("AZURE_CERTIFICATE_PATH")
certificatePassword := os.Getenv("AZURE_CERTIFICATE_PASSWORD")
envName := os.Getenv("AZURE_ENVIRONMENT")
//
//
// The Azure Environment used can be specified using the name of the Azure Environment set in "AZURE_ENVIRONMENT" var.
func NewProviderFromEnvironment(opts ...JwtTokenProviderOption) (auth.TokenProvider, error) {
config := &tokenProviderConfiguration{
tenantID: os.Getenv("AZURE_TENANT_ID"),
clientID: os.Getenv("AZURE_CLIENT_ID"),
clientSecret: os.Getenv("AZURE_CLIENT_SECRET"),
certificatePath: os.Getenv("AZURE_CERTIFICATE_PATH"),
certificatePassword: os.Getenv("AZURE_CERTIFICATE_PASSWORD"),
}
var env azure.Environment
if envName == "" {
env = azure.PublicCloud
} else {
var err error
env, err = azure.EnvironmentFromName(envName)
for _, opt := range opts {
err := opt(config)
if err != nil {
return nil, err
}
}
oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, tenantID)
if config.env == nil {
env, err := azureEnvFromEnvironment()
if err != nil {
return nil, err
}
config.env = env
}
spToken, err := config.newServicePrincipalToken()
if err != nil {
return nil, err
}
return NewProvider(spToken), nil
}
func (c *tokenProviderConfiguration) newServicePrincipalToken() (*adal.ServicePrincipalToken, error) {
oauthConfig, err := adal.NewOAuthConfig(c.env.ActiveDirectoryEndpoint, c.tenantID)
if err != nil {
return nil, err
}
// 1.Client Credentials
if clientSecret != "" {
if c.clientSecret != "" {
log.Debug("creating a token via a service principal client secret")
spToken, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, resource)
spToken, err := adal.NewServicePrincipalToken(*oauthConfig, c.clientID, c.clientSecret, resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from client credentials: %v", err)
}
if err := spToken.Refresh(); err != nil {
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
}
return NewProvider(spToken), nil
return spToken, nil
}
// 2. Client Certificate
if certificatePath != "" {
if c.certificatePath != "" {
log.Debug("creating a token via a service principal client certificate")
certData, err := ioutil.ReadFile(certificatePath)
certData, err := ioutil.ReadFile(c.certificatePath)
if err != nil {
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", c.certificatePath, err)
}
certificate, rsaPrivateKey, err := decodePkcs12(certData, certificatePassword)
certificate, rsaPrivateKey, err := decodePkcs12(certData, c.certificatePassword)
if err != nil {
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
}
spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, certificate, rsaPrivateKey, resource)
spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, c.clientID, certificate, rsaPrivateKey, resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from certificate auth: %v", err)
}
if err := spToken.Refresh(); err != nil {
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
}
return NewProvider(spToken), nil
return spToken, nil
}
// 3. By default return MSI
@ -114,7 +150,7 @@ func NewProviderFromEnvironment() (auth.TokenProvider, error) {
if err := spToken.Refresh(); err != nil {
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
}
return NewProvider(spToken), nil
return spToken, nil
}
// GetToken gets a CBS JWT token
@ -152,3 +188,19 @@ func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.Private
return certificate, rsaPrivateKey, nil
}
func azureEnvFromEnvironment() (*azure.Environment, error) {
envName := os.Getenv("AZURE_ENVIRONMENT")
var env azure.Environment
if envName == "" {
env = azure.PublicCloud
} else {
var err error
env, err = azure.EnvironmentFromName(envName)
if err != nil {
return nil, err
}
}
return &env, nil
}

10
hub.go
Просмотреть файл

@ -233,6 +233,16 @@ func HubWithUserAgent(userAgent string) HubOption {
}
}
// HubWithEnvironment configures the hub to use the specified environment.
//
// By default, the hub instance will use Azure US Public cloud environment
func HubWithEnvironment(env azure.Environment) HubOption {
return func(h *hub) error {
h.namespace.environment = env
return nil
}
}
func (h *hub) appendAgent(userAgent string) error {
ua := path.Join(h.userAgent, userAgent)
if len(ua) > maxUserAgentLen {

Просмотреть файл

@ -10,6 +10,7 @@ import (
"time"
"github.com/Azure/azure-event-hubs-go/aad"
"github.com/Azure/azure-event-hubs-go/auth"
"github.com/Azure/azure-event-hubs-go/sas"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@ -18,7 +19,7 @@ import (
func (suite *eventHubSuite) TestSasToken() {
tests := map[string]func(*testing.T, Client, []string, string){
"TestMultiSendAndReceive": testMultiSendAndReceive,
//"TestMultiSendAndReceive": testMultiSendAndReceive,
"TestHubRuntimeInformation": testHubRuntimeInformation,
"TestHubPartitionRuntimeInformation": testHubPartitionRuntimeInformation,
}
@ -35,11 +36,7 @@ func (suite *eventHubSuite) TestSasToken() {
if err != nil {
t.Fatal(err)
}
client, err := NewClient(suite.namespace, hubName, provider)
if err != nil {
t.Fatal(err)
}
client := suite.newClientWithProvider(t, hubName, provider)
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
if err := client.Close(); err != nil {
t.Fatal(err)
@ -65,14 +62,7 @@ func (suite *eventHubSuite) TestPartitionedSender() {
}
defer suite.deleteEventHub(context.Background(), hubName)
partitionID := (*mgmtHub.PartitionIds)[0]
provider, err := aad.NewProviderFromEnvironment()
if err != nil {
t.Fatal(err)
}
client, err := NewClient(suite.namespace, hubName, provider, HubWithPartitionedSender(partitionID))
if err != nil {
t.Fatal(err)
}
client := suite.newClient(t, hubName, HubWithPartitionedSender(partitionID))
testFunc(t, client, partitionID)
if err := client.Close(); err != nil {
@ -136,15 +126,7 @@ func (suite *eventHubSuite) TestMultiPartition() {
t.Fatal(err)
}
defer suite.deleteEventHub(context.Background(), hubName)
provider, err := aad.NewProviderFromEnvironment()
if err != nil {
t.Fatal(err)
}
client, err := NewClient(suite.namespace, hubName, provider)
if err != nil {
t.Fatal(err)
}
client := suite.newClient(t, hubName)
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
if err := client.Close(); err != nil {
t.Fatal(err)
@ -202,16 +184,8 @@ func (suite *eventHubSuite) TestHubManagement() {
t.Fatal(err)
}
defer suite.deleteEventHub(context.Background(), hubName)
provider, err := aad.NewProviderFromEnvironment()
if err != nil {
t.Fatal(err)
}
client, err := NewClient(suite.namespace, hubName, provider)
if err != nil {
t.Fatal(err)
}
testFunc(t, client, *mgmtHub.PartitionIds, *mgmtHub.Name)
client := suite.newClient(t, hubName)
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
if err := client.Close(); err != nil {
t.Fatal(err)
}
@ -226,6 +200,7 @@ func testHubRuntimeInformation(t *testing.T, client Client, partitionIDs []strin
if err != nil {
t.Fatal(err)
}
log.Debug(info.PartitionIDs)
assert.Equal(t, len(partitionIDs), info.PartitionCount)
assert.Equal(t, hubName, info.Path)
}
@ -304,3 +279,20 @@ func BenchmarkReceive(b *testing.B) {
wg.Wait()
b.StopTimer()
}
func (suite *eventHubSuite) newClient(t *testing.T, hubName string, opts ...HubOption) Client {
provider, err := aad.NewProviderFromEnvironment(aad.JwtTokenProviderWithEnvironment(&suite.env))
if err != nil {
t.Fatal(err)
}
return suite.newClientWithProvider(t, hubName, provider, opts...)
}
func (suite *eventHubSuite) newClientWithProvider(t *testing.T, hubName string, provider auth.TokenProvider, opts ...HubOption) Client {
opts = append(opts, HubWithEnvironment(suite.env))
client, err := NewClient(suite.namespace, hubName, provider, opts...)
if err != nil {
t.Fatal(err)
}
return client
}