pass environment down from hub options
This commit is contained in:
Родитель
4b4cd3f5f0
Коммит
b4be1d2439
100
aad/jwt.go
100
aad/jwt.go
|
@ -21,12 +21,32 @@ const (
|
|||
)
|
||||
|
||||
type (
|
||||
tokenProviderConfiguration struct {
|
||||
tenantID string
|
||||
clientID string
|
||||
clientSecret string
|
||||
certificatePath string
|
||||
certificatePassword string
|
||||
env *azure.Environment
|
||||
}
|
||||
|
||||
// TokenProvider provides cbs.TokenProvider functionality for Azure Active Directory JWT tokens
|
||||
TokenProvider struct {
|
||||
tokenProvider *adal.ServicePrincipalToken
|
||||
}
|
||||
|
||||
// JwtTokenProviderOption provides configuration options for constructing AAD Token Providers
|
||||
JwtTokenProviderOption func(provider *tokenProviderConfiguration) error
|
||||
)
|
||||
|
||||
// JwtTokenProviderWithEnvironment configures the token provider to use a specific Azure Environment
|
||||
func JwtTokenProviderWithEnvironment(env *azure.Environment) JwtTokenProviderOption {
|
||||
return func(config *tokenProviderConfiguration) error {
|
||||
config.env = env
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewProvider builds an Azure Active Directory claims-based security token provider
|
||||
func NewProvider(tokenProvider *adal.ServicePrincipalToken) auth.TokenProvider {
|
||||
return &TokenProvider{
|
||||
|
@ -43,62 +63,78 @@ func NewProvider(tokenProvider *adal.ServicePrincipalToken) auth.TokenProvider {
|
|||
// "AZURE_CERTIFICATE_PATH" and "AZURE_CERTIFICATE_PASSWORD"
|
||||
//
|
||||
// 3. Managed Service Identity (MSI): attempt to authenticate via MSI
|
||||
func NewProviderFromEnvironment() (auth.TokenProvider, error) {
|
||||
tenantID := os.Getenv("AZURE_TENANT_ID")
|
||||
clientID := os.Getenv("AZURE_CLIENT_ID")
|
||||
clientSecret := os.Getenv("AZURE_CLIENT_SECRET")
|
||||
certificatePath := os.Getenv("AZURE_CERTIFICATE_PATH")
|
||||
certificatePassword := os.Getenv("AZURE_CERTIFICATE_PASSWORD")
|
||||
envName := os.Getenv("AZURE_ENVIRONMENT")
|
||||
//
|
||||
//
|
||||
// The Azure Environment used can be specified using the name of the Azure Environment set in "AZURE_ENVIRONMENT" var.
|
||||
func NewProviderFromEnvironment(opts ...JwtTokenProviderOption) (auth.TokenProvider, error) {
|
||||
config := &tokenProviderConfiguration{
|
||||
tenantID: os.Getenv("AZURE_TENANT_ID"),
|
||||
clientID: os.Getenv("AZURE_CLIENT_ID"),
|
||||
clientSecret: os.Getenv("AZURE_CLIENT_SECRET"),
|
||||
certificatePath: os.Getenv("AZURE_CERTIFICATE_PATH"),
|
||||
certificatePassword: os.Getenv("AZURE_CERTIFICATE_PASSWORD"),
|
||||
}
|
||||
|
||||
var env azure.Environment
|
||||
if envName == "" {
|
||||
env = azure.PublicCloud
|
||||
} else {
|
||||
var err error
|
||||
env, err = azure.EnvironmentFromName(envName)
|
||||
for _, opt := range opts {
|
||||
err := opt(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, tenantID)
|
||||
if config.env == nil {
|
||||
env, err := azureEnvFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.env = env
|
||||
}
|
||||
|
||||
spToken, err := config.newServicePrincipalToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewProvider(spToken), nil
|
||||
}
|
||||
|
||||
func (c *tokenProviderConfiguration) newServicePrincipalToken() (*adal.ServicePrincipalToken, error) {
|
||||
oauthConfig, err := adal.NewOAuthConfig(c.env.ActiveDirectoryEndpoint, c.tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 1.Client Credentials
|
||||
if clientSecret != "" {
|
||||
if c.clientSecret != "" {
|
||||
log.Debug("creating a token via a service principal client secret")
|
||||
spToken, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, resource)
|
||||
spToken, err := adal.NewServicePrincipalToken(*oauthConfig, c.clientID, c.clientSecret, resource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get oauth token from client credentials: %v", err)
|
||||
}
|
||||
if err := spToken.Refresh(); err != nil {
|
||||
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
|
||||
}
|
||||
return NewProvider(spToken), nil
|
||||
return spToken, nil
|
||||
}
|
||||
|
||||
// 2. Client Certificate
|
||||
if certificatePath != "" {
|
||||
if c.certificatePath != "" {
|
||||
log.Debug("creating a token via a service principal client certificate")
|
||||
certData, err := ioutil.ReadFile(certificatePath)
|
||||
certData, err := ioutil.ReadFile(c.certificatePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
|
||||
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", c.certificatePath, err)
|
||||
}
|
||||
certificate, rsaPrivateKey, err := decodePkcs12(certData, certificatePassword)
|
||||
certificate, rsaPrivateKey, err := decodePkcs12(certData, c.certificatePassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
|
||||
}
|
||||
spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, certificate, rsaPrivateKey, resource)
|
||||
spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, c.clientID, certificate, rsaPrivateKey, resource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get oauth token from certificate auth: %v", err)
|
||||
}
|
||||
if err := spToken.Refresh(); err != nil {
|
||||
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
|
||||
}
|
||||
return NewProvider(spToken), nil
|
||||
return spToken, nil
|
||||
}
|
||||
|
||||
// 3. By default return MSI
|
||||
|
@ -114,7 +150,7 @@ func NewProviderFromEnvironment() (auth.TokenProvider, error) {
|
|||
if err := spToken.Refresh(); err != nil {
|
||||
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
|
||||
}
|
||||
return NewProvider(spToken), nil
|
||||
return spToken, nil
|
||||
}
|
||||
|
||||
// GetToken gets a CBS JWT token
|
||||
|
@ -152,3 +188,19 @@ func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.Private
|
|||
|
||||
return certificate, rsaPrivateKey, nil
|
||||
}
|
||||
|
||||
func azureEnvFromEnvironment() (*azure.Environment, error) {
|
||||
envName := os.Getenv("AZURE_ENVIRONMENT")
|
||||
|
||||
var env azure.Environment
|
||||
if envName == "" {
|
||||
env = azure.PublicCloud
|
||||
} else {
|
||||
var err error
|
||||
env, err = azure.EnvironmentFromName(envName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &env, nil
|
||||
}
|
||||
|
|
10
hub.go
10
hub.go
|
@ -233,6 +233,16 @@ func HubWithUserAgent(userAgent string) HubOption {
|
|||
}
|
||||
}
|
||||
|
||||
// HubWithEnvironment configures the hub to use the specified environment.
|
||||
//
|
||||
// By default, the hub instance will use Azure US Public cloud environment
|
||||
func HubWithEnvironment(env azure.Environment) HubOption {
|
||||
return func(h *hub) error {
|
||||
h.namespace.environment = env
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h *hub) appendAgent(userAgent string) error {
|
||||
ua := path.Join(h.userAgent, userAgent)
|
||||
if len(ua) > maxUserAgentLen {
|
||||
|
|
58
hub_test.go
58
hub_test.go
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/aad"
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"github.com/Azure/azure-event-hubs-go/sas"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -18,7 +19,7 @@ import (
|
|||
|
||||
func (suite *eventHubSuite) TestSasToken() {
|
||||
tests := map[string]func(*testing.T, Client, []string, string){
|
||||
"TestMultiSendAndReceive": testMultiSendAndReceive,
|
||||
//"TestMultiSendAndReceive": testMultiSendAndReceive,
|
||||
"TestHubRuntimeInformation": testHubRuntimeInformation,
|
||||
"TestHubPartitionRuntimeInformation": testHubPartitionRuntimeInformation,
|
||||
}
|
||||
|
@ -35,11 +36,7 @@ func (suite *eventHubSuite) TestSasToken() {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client, err := NewClient(suite.namespace, hubName, provider)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := suite.newClientWithProvider(t, hubName, provider)
|
||||
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
|
||||
if err := client.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -65,14 +62,7 @@ func (suite *eventHubSuite) TestPartitionedSender() {
|
|||
}
|
||||
defer suite.deleteEventHub(context.Background(), hubName)
|
||||
partitionID := (*mgmtHub.PartitionIds)[0]
|
||||
provider, err := aad.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client, err := NewClient(suite.namespace, hubName, provider, HubWithPartitionedSender(partitionID))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client := suite.newClient(t, hubName, HubWithPartitionedSender(partitionID))
|
||||
|
||||
testFunc(t, client, partitionID)
|
||||
if err := client.Close(); err != nil {
|
||||
|
@ -136,15 +126,7 @@ func (suite *eventHubSuite) TestMultiPartition() {
|
|||
t.Fatal(err)
|
||||
}
|
||||
defer suite.deleteEventHub(context.Background(), hubName)
|
||||
provider, err := aad.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client, err := NewClient(suite.namespace, hubName, provider)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := suite.newClient(t, hubName)
|
||||
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
|
||||
if err := client.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -202,16 +184,8 @@ func (suite *eventHubSuite) TestHubManagement() {
|
|||
t.Fatal(err)
|
||||
}
|
||||
defer suite.deleteEventHub(context.Background(), hubName)
|
||||
provider, err := aad.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client, err := NewClient(suite.namespace, hubName, provider)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testFunc(t, client, *mgmtHub.PartitionIds, *mgmtHub.Name)
|
||||
client := suite.newClient(t, hubName)
|
||||
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
|
||||
if err := client.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -226,6 +200,7 @@ func testHubRuntimeInformation(t *testing.T, client Client, partitionIDs []strin
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
log.Debug(info.PartitionIDs)
|
||||
assert.Equal(t, len(partitionIDs), info.PartitionCount)
|
||||
assert.Equal(t, hubName, info.Path)
|
||||
}
|
||||
|
@ -304,3 +279,20 @@ func BenchmarkReceive(b *testing.B) {
|
|||
wg.Wait()
|
||||
b.StopTimer()
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) newClient(t *testing.T, hubName string, opts ...HubOption) Client {
|
||||
provider, err := aad.NewProviderFromEnvironment(aad.JwtTokenProviderWithEnvironment(&suite.env))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return suite.newClientWithProvider(t, hubName, provider, opts...)
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) newClientWithProvider(t *testing.T, hubName string, provider auth.TokenProvider, opts ...HubOption) Client {
|
||||
opts = append(opts, HubWithEnvironment(suite.env))
|
||||
client, err := NewClient(suite.namespace, hubName, provider, opts...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче