Merge pull request #9 from devigned/cleanup
Clean up and add some environmental help
This commit is contained in:
Коммит
1d3f2d42fa
|
@ -17,6 +17,7 @@
|
|||
"autorest",
|
||||
"autorest/adal",
|
||||
"autorest/azure",
|
||||
"autorest/azure/auth",
|
||||
"autorest/date",
|
||||
"autorest/to",
|
||||
"autorest/validation"
|
||||
|
@ -36,6 +37,12 @@
|
|||
revision = "dbeaa9332f19a944acb5736b4456cfcc02140e29"
|
||||
version = "v3.1.0"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/dimchansky/utfbom"
|
||||
packages = ["."]
|
||||
revision = "6c6132ff69f0f6c088739067407b5d32c52e1d0f"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/pkg/errors"
|
||||
packages = ["."]
|
||||
|
@ -73,8 +80,12 @@
|
|||
[[projects]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/crypto"
|
||||
packages = ["ssh/terminal"]
|
||||
revision = "650f4a345ab4e5b245a3034b110ebc7299e68186"
|
||||
packages = [
|
||||
"pkcs12",
|
||||
"pkcs12/internal/rc2",
|
||||
"ssh/terminal"
|
||||
]
|
||||
revision = "432090b8f568c018896cd8a0fb0345872bbac6ce"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
|
@ -97,6 +108,6 @@
|
|||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
inputs-digest = "5f18a64a7717270d06573eadd917fc9fb583eb9953b9811f04fb4bc04a0c0a7f"
|
||||
inputs-digest = "b7c45c0ec32d34417f3a7f0cd4ad053b98b3bb6c56d189c9585a2e10c2f49a62"
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
|
|
2
Makefile
2
Makefile
|
@ -17,7 +17,7 @@ DEP = dep
|
|||
V = 0
|
||||
Q = $(if $(filter 1,$V),,@)
|
||||
M = $(shell printf "\033[34;1m▶\033[0m")
|
||||
TIMEOUT = 100
|
||||
TIMEOUT = 200
|
||||
|
||||
.PHONY: all
|
||||
all: fmt vendor lint vet | $(BASE) ; $(info $(M) building library…) @ ## Build program
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/Azure/azure-event-hubs-go"
|
||||
"fmt"
|
||||
"time"
|
||||
"os"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"log"
|
||||
"context"
|
||||
"pack.ag/amqp"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go"
|
||||
"github.com/Azure/azure-event-hubs-go/aad"
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
azauth "github.com/Azure/go-autorest/autorest/azure/auth"
|
||||
"pack.ag/amqp"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -26,19 +26,21 @@ func main() {
|
|||
exit := make(chan struct{})
|
||||
|
||||
handler := func(ctx context.Context, msg *amqp.Message) error {
|
||||
text := string(msg.Data)
|
||||
text := string(msg.Data[0])
|
||||
if text == "exit\n" {
|
||||
fmt.Println("Someone told me to exit!")
|
||||
exit <- *new(struct{})
|
||||
} else {
|
||||
fmt.Println(string(msg.Data))
|
||||
fmt.Println(string(msg.Data[0]))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
for _, partitionID := range partitions {
|
||||
hub.Receive(partitionID, handler)
|
||||
hub.Receive(ctx, partitionID, handler)
|
||||
}
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-exit:
|
||||
|
@ -57,11 +59,10 @@ func initHub() (eventhub.Client, []string) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
aadToken, err := getEventHubsTokenProvider()
|
||||
provider, err := aad.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
provider := aad.NewProvider(aadToken)
|
||||
hub, err := eventhub.NewClient(namespace, HubName, provider)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -77,30 +78,6 @@ func mustGetenv(key string) string {
|
|||
return v
|
||||
}
|
||||
|
||||
func getEventHubsTokenProvider() (*adal.ServicePrincipalToken, error) {
|
||||
// TODO: fix the azure environment var for the SB endpoint and EH endpoint
|
||||
return getTokenProvider("https://eventhubs.azure.net/")
|
||||
}
|
||||
|
||||
func getTokenProvider(resourceURI string) (*adal.ServicePrincipalToken, error) {
|
||||
oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, mustGetenv("AZURE_TENANT_ID"))
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
tokenProvider, err := adal.NewServicePrincipalToken(*oauthConfig, mustGetenv("AZURE_CLIENT_ID"), mustGetenv("AZURE_CLIENT_SECRET"), resourceURI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tokenProvider.Refresh()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokenProvider, nil
|
||||
}
|
||||
|
||||
func ensureEventHub(ctx context.Context, name string) (*mgmt.Model, error) {
|
||||
namespace := mustGetenv("EVENTHUB_NAMESPACE")
|
||||
client := getEventHubMgmtClient()
|
||||
|
@ -126,10 +103,10 @@ func ensureEventHub(ctx context.Context, name string) (*mgmt.Model, error) {
|
|||
func getEventHubMgmtClient() *mgmt.EventHubsClient {
|
||||
subID := mustGetenv("AZURE_SUBSCRIPTION_ID")
|
||||
client := mgmt.NewEventHubsClientWithBaseURI(azure.PublicCloud.ResourceManagerEndpoint, subID)
|
||||
armToken, err := getTokenProvider(azure.PublicCloud.ResourceManagerEndpoint)
|
||||
a, err := azauth.NewAuthorizerFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
client.Authorizer = autorest.NewBearerAuthorizer(armToken)
|
||||
client.Authorizer = a
|
||||
return &client
|
||||
}
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/Azure/azure-event-hubs-go"
|
||||
"fmt"
|
||||
"os"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"log"
|
||||
"context"
|
||||
"pack.ag/amqp"
|
||||
"bufio"
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go"
|
||||
"github.com/Azure/azure-event-hubs-go/aad"
|
||||
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
azauth "github.com/Azure/go-autorest/autorest/azure/auth"
|
||||
"pack.ag/amqp"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -28,10 +29,12 @@ func main() {
|
|||
for {
|
||||
fmt.Print("Enter text: ")
|
||||
text, _ := reader.ReadString('\n')
|
||||
hub.Send(context.Background(), &amqp.Message{Data: []byte(text)})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
hub.Send(ctx, amqp.NewMessage([]byte(text)))
|
||||
if text == "exit\n" {
|
||||
break
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,11 +45,10 @@ func initHub() (eventhub.Client, []string) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
aadToken, err := getEventHubsTokenProvider()
|
||||
provider, err := aad.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
provider := aad.NewProvider(aadToken)
|
||||
hub, err := eventhub.NewClient(namespace, HubName, provider)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -62,30 +64,6 @@ func mustGetenv(key string) string {
|
|||
return v
|
||||
}
|
||||
|
||||
func getEventHubsTokenProvider() (*adal.ServicePrincipalToken, error) {
|
||||
// TODO: fix the azure environment var for the SB endpoint and EH endpoint
|
||||
return getTokenProvider("https://eventhubs.azure.net/")
|
||||
}
|
||||
|
||||
func getTokenProvider(resourceURI string) (*adal.ServicePrincipalToken, error) {
|
||||
oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, mustGetenv("AZURE_TENANT_ID"))
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
tokenProvider, err := adal.NewServicePrincipalToken(*oauthConfig, mustGetenv("AZURE_CLIENT_ID"), mustGetenv("AZURE_CLIENT_SECRET"), resourceURI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tokenProvider.Refresh()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokenProvider, nil
|
||||
}
|
||||
|
||||
func ensureEventHub(ctx context.Context, name string) (*mgmt.Model, error) {
|
||||
namespace := mustGetenv("EVENTHUB_NAMESPACE")
|
||||
client := getEventHubMgmtClient()
|
||||
|
@ -111,11 +89,11 @@ func ensureEventHub(ctx context.Context, name string) (*mgmt.Model, error) {
|
|||
func getEventHubMgmtClient() *mgmt.EventHubsClient {
|
||||
subID := mustGetenv("AZURE_SUBSCRIPTION_ID")
|
||||
client := mgmt.NewEventHubsClientWithBaseURI(azure.PublicCloud.ResourceManagerEndpoint, subID)
|
||||
armToken, err := getTokenProvider(azure.PublicCloud.ResourceManagerEndpoint)
|
||||
a, err := azauth.NewAuthorizerFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
client.Authorizer = autorest.NewBearerAuthorizer(armToken)
|
||||
client.Authorizer = a
|
||||
return &client
|
||||
}
|
||||
|
||||
|
|
116
aad/jwt.go
116
aad/jwt.go
|
@ -1,11 +1,23 @@
|
|||
package aad
|
||||
|
||||
import (
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/pkcs12"
|
||||
)
|
||||
|
||||
const (
|
||||
resource = "https://eventhubs.azure.net/"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -22,11 +34,95 @@ func NewProvider(tokenProvider *adal.ServicePrincipalToken) auth.TokenProvider {
|
|||
}
|
||||
}
|
||||
|
||||
// NewProviderFromEnvironment builds a new TokenProvider using environment variable available
|
||||
//
|
||||
// 1. Client Credentials: attempt to authenticate with a Service Principal via "AZURE_TENANT_ID", "AZURE_CLIENT_ID" and
|
||||
// "AZURE_CLIENT_SECRET"
|
||||
//
|
||||
// 2. Client Certificate: attempt to authenticate with a Service Principal via "AZURE_TENANT_ID", "AZURE_CLIENT_ID",
|
||||
// "AZURE_CERTIFICATE_PATH" and "AZURE_CERTIFICATE_PASSWORD"
|
||||
//
|
||||
// 3. Managed Service Identity (MSI): attempt to authenticate via MSI
|
||||
func NewProviderFromEnvironment() (auth.TokenProvider, error) {
|
||||
tenantID := os.Getenv("AZURE_TENANT_ID")
|
||||
clientID := os.Getenv("AZURE_CLIENT_ID")
|
||||
clientSecret := os.Getenv("AZURE_CLIENT_SECRET")
|
||||
certificatePath := os.Getenv("AZURE_CERTIFICATE_PATH")
|
||||
certificatePassword := os.Getenv("AZURE_CERTIFICATE_PASSWORD")
|
||||
envName := os.Getenv("AZURE_ENVIRONMENT")
|
||||
|
||||
var env azure.Environment
|
||||
if envName == "" {
|
||||
env = azure.PublicCloud
|
||||
} else {
|
||||
var err error
|
||||
env, err = azure.EnvironmentFromName(envName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 1.Client Credentials
|
||||
if clientSecret != "" {
|
||||
log.Debug("creating a token via a service principal client secret")
|
||||
spToken, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, resource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get oauth token from client credentials: %v", err)
|
||||
}
|
||||
if err := spToken.Refresh(); err != nil {
|
||||
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
|
||||
}
|
||||
return NewProvider(spToken), nil
|
||||
}
|
||||
|
||||
// 2. Client Certificate
|
||||
if certificatePath != "" {
|
||||
log.Debug("creating a token via a service principal client certificate")
|
||||
certData, err := ioutil.ReadFile(certificatePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
|
||||
}
|
||||
certificate, rsaPrivateKey, err := decodePkcs12(certData, certificatePassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
|
||||
}
|
||||
spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, certificate, rsaPrivateKey, resource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get oauth token from certificate auth: %v", err)
|
||||
}
|
||||
if err := spToken.Refresh(); err != nil {
|
||||
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
|
||||
}
|
||||
return NewProvider(spToken), nil
|
||||
}
|
||||
|
||||
// 3. By default return MSI
|
||||
log.Debug("creating a token via MSI")
|
||||
msiEndpoint, err := adal.GetMSIVMEndpoint()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
spToken, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get oauth token from MSI: %v", err)
|
||||
}
|
||||
if err := spToken.Refresh(); err != nil {
|
||||
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
|
||||
}
|
||||
return NewProvider(spToken), nil
|
||||
}
|
||||
|
||||
// GetToken gets a CBS JWT token
|
||||
func (t *TokenProvider) GetToken(audience string) (*auth.Token, error) {
|
||||
token := t.tokenProvider.Token()
|
||||
expireTicks, err := strconv.Atoi(token.ExpiresOn)
|
||||
if err != nil {
|
||||
log.Debugf("%v", token.AccessToken)
|
||||
return nil, err
|
||||
}
|
||||
currentTicks := time.Now().UTC().Unix()
|
||||
|
@ -42,3 +138,17 @@ func (t *TokenProvider) GetToken(audience string) (*auth.Token, error) {
|
|||
|
||||
return auth.NewToken(auth.CbsTokenTypeJwt, token.AccessToken, token.ExpiresOn), nil
|
||||
}
|
||||
|
||||
func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
|
||||
privateKey, certificate, err := pkcs12.Decode(pkcs, password)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
|
||||
if !isRsaKey {
|
||||
return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
|
||||
}
|
||||
|
||||
return certificate, rsaPrivateKey, nil
|
||||
}
|
||||
|
|
|
@ -2,11 +2,12 @@ package cbs
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"github.com/Azure/azure-event-hubs-go/rpc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"pack.ag/amqp"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -19,7 +20,7 @@ const (
|
|||
)
|
||||
|
||||
// NegotiateClaim attempts to put a token to the $cbs management endpoint to negotiate auth for the given audience
|
||||
func NegotiateClaim(audience string, conn *amqp.Client, provider auth.TokenProvider) error {
|
||||
func NegotiateClaim(ctx context.Context, audience string, conn *amqp.Client, provider auth.TokenProvider) error {
|
||||
link, err := rpc.NewLink(conn, cbsAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -42,7 +43,7 @@ func NegotiateClaim(audience string, conn *amqp.Client, provider auth.TokenProvi
|
|||
},
|
||||
}
|
||||
|
||||
res, err := link.RetryableRPC(context.Background(), 3, 1*time.Second, msg)
|
||||
res, err := link.RetryableRPC(ctx, 3, 1*time.Second, msg)
|
||||
if err == nil {
|
||||
log.Debugf("negotiated with response code %d and message: %s", res.Code, res.Description)
|
||||
} else {
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var (
|
||||
connStrRegex = regexp.MustCompile(`Endpoint=sb:\/\/(?P<Host>.+?);SharedAccessKeyName=(?P<KeyName>.+?);SharedAccessKey=(?P<Key>.+)`)
|
||||
hostStrRegex = regexp.MustCompile(`^(?P<Namespace>.+?)\..+`)
|
||||
)
|
||||
|
||||
type (
|
||||
// ParsedConn is the structure of a parsed Service Bus or Event Hub connection string.
|
||||
ParsedConn struct {
|
||||
Host string
|
||||
Namespace string
|
||||
KeyName string
|
||||
Key string
|
||||
}
|
||||
)
|
||||
|
||||
// newParsedConnection is a constructor for a parsedConn and verifies each of the inputs is non-null.
|
||||
func newParsedConnection(host, namespace, keyName, key string) (*ParsedConn, error) {
|
||||
if host == "" || keyName == "" || key == "" {
|
||||
return nil, errors.New("connection string contains an empty entry")
|
||||
}
|
||||
return &ParsedConn{
|
||||
Host: "amqps://" + host,
|
||||
Namespace: namespace,
|
||||
KeyName: keyName,
|
||||
Key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ParsedConnectionFromStr takes a string connection string from the Azure portal and returns the parsed representation.
|
||||
func ParsedConnectionFromStr(connStr string) (*ParsedConn, error) {
|
||||
matches := connStrRegex.FindStringSubmatch(connStr)
|
||||
namespaceMatches := hostStrRegex.FindStringSubmatch(matches[1])
|
||||
return newParsedConnection(matches[1], namespaceMatches[1], matches[2], matches[3])
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
namespace = "mynamespace"
|
||||
keyName = "keyName"
|
||||
secret = "superSecret"
|
||||
connStr = "Endpoint=sb://" + namespace + ".servicebus.windows.net/;SharedAccessKeyName=" + keyName + ";SharedAccessKey=" + secret
|
||||
)
|
||||
|
||||
func TestParsedConnectionFromStr(t *testing.T) {
|
||||
parsed, err := ParsedConnectionFromStr(connStr)
|
||||
assert.Nil(t, err, err)
|
||||
assert.Equal(t, "amqps://"+namespace+".servicebus.windows.net/", parsed.Host)
|
||||
assert.Equal(t, namespace, parsed.Namespace)
|
||||
assert.Equal(t, keyName, parsed.KeyName)
|
||||
assert.Equal(t, secret, parsed.Key)
|
||||
}
|
87
hub.go
87
hub.go
|
@ -3,13 +3,19 @@ package eventhub
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"github.com/Azure/azure-event-hubs-go/mgmt"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/pkg/errors"
|
||||
"pack.ag/amqp"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/aad"
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"github.com/Azure/azure-event-hubs-go/mgmt"
|
||||
"github.com/Azure/azure-event-hubs-go/persist"
|
||||
"github.com/Azure/azure-event-hubs-go/sas"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"pack.ag/amqp"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -26,7 +32,7 @@ type (
|
|||
senderPartitionID *string
|
||||
receiverMu sync.Mutex
|
||||
senderMu sync.Mutex
|
||||
offsetPersister OffsetPersister
|
||||
offsetPersister persist.OffsetPersister
|
||||
userAgent string
|
||||
}
|
||||
|
||||
|
@ -40,7 +46,7 @@ type (
|
|||
|
||||
// Receiver provides the ability to receive messages
|
||||
Receiver interface {
|
||||
Receive(partitionID string, handler Handler, opts ...ReceiveOption) error
|
||||
Receive(ctx context.Context, partitionID string, handler Handler, opts ...ReceiveOption) error
|
||||
}
|
||||
|
||||
// Closer provides the ability to close a connection or client
|
||||
|
@ -64,13 +70,6 @@ type (
|
|||
|
||||
// HubOption provides structure for configuring new Event Hub instances
|
||||
HubOption func(h *hub) error
|
||||
|
||||
// OffsetPersister provides persistence for the received offset for a given namespace, hub name, consumer group, partition Id and
|
||||
// offset so that if a receiver where to be interrupted, it could resume after the last consumed event.
|
||||
OffsetPersister interface {
|
||||
Write(namespace, name, consumerGroup, partitionID, offset string) error
|
||||
Read(namespace, name, consumerGroup, partitionID string) (string, error)
|
||||
}
|
||||
)
|
||||
|
||||
// NewClient creates a new Event Hub client for sending and receiving messages
|
||||
|
@ -79,7 +78,7 @@ func NewClient(namespace, name string, tokenProvider auth.TokenProvider, opts ..
|
|||
h := &hub{
|
||||
name: name,
|
||||
namespace: ns,
|
||||
offsetPersister: new(MemoryPersister),
|
||||
offsetPersister: new(persist.MemoryPersister),
|
||||
userAgent: rootUserAgent,
|
||||
}
|
||||
|
||||
|
@ -93,6 +92,45 @@ func NewClient(namespace, name string, tokenProvider auth.TokenProvider, opts ..
|
|||
return h, nil
|
||||
}
|
||||
|
||||
// NewClientFromEnvironment creates a new Event Hub client for sending and receiving messages from environment variables
|
||||
func NewClientFromEnvironment(opts ...HubOption) (Client, error) {
|
||||
const envErrMsg = "environment var %s must not be empty"
|
||||
var namespace, name string
|
||||
var provider auth.TokenProvider
|
||||
|
||||
if namespace = os.Getenv("EVENTHUB_NAMESPACE"); namespace == "" {
|
||||
return nil, errors.Errorf(envErrMsg, "EVENTHUB_NAMESPACE")
|
||||
}
|
||||
|
||||
if name = os.Getenv("EVENTHUB_NAME"); name == "" {
|
||||
return nil, errors.Errorf(envErrMsg, "EVENTHUB_NAME")
|
||||
}
|
||||
|
||||
aadProvider, aadErr := aad.NewProviderFromEnvironment()
|
||||
sasProvider, sasErr := sas.NewProviderFromEnvironment()
|
||||
|
||||
if aadErr != nil && sasErr != nil {
|
||||
// both failed
|
||||
log.Debug("both token providers failed")
|
||||
return nil, errors.Errorf("neither Azure Active Directory nor SAS token provider could be built - AAD error: %v, SAS error: %v", aadErr, sasErr)
|
||||
}
|
||||
|
||||
if aadProvider != nil {
|
||||
log.Debug("using AAD provider")
|
||||
provider = aadProvider
|
||||
} else {
|
||||
log.Debug("using SAS provider")
|
||||
provider = sasProvider
|
||||
}
|
||||
|
||||
h, err := NewClient(namespace, name, provider, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// GetRuntimeInformation fetches runtime information from the Event Hub management node
|
||||
func (h *hub) GetRuntimeInformation(ctx context.Context) (*mgmt.HubRuntimeInformation, error) {
|
||||
client := mgmt.NewClient(h.namespace.name, h.name, h.namespace.tokenProvider, h.namespace.environment)
|
||||
|
@ -123,18 +161,21 @@ func (h *hub) GetPartitionInformation(ctx context.Context, partitionID string) (
|
|||
|
||||
// Close drains and closes all of the existing senders, receivers and connections
|
||||
func (h *hub) Close() error {
|
||||
var lastErr error
|
||||
for _, r := range h.receivers {
|
||||
r.Close()
|
||||
if err := r.Close(); err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Listen subscribes for messages sent to the provided entityPath.
|
||||
func (h *hub) Receive(partitionID string, handler Handler, opts ...ReceiveOption) error {
|
||||
func (h *hub) Receive(ctx context.Context, partitionID string, handler Handler, opts ...ReceiveOption) error {
|
||||
h.receiverMu.Lock()
|
||||
defer h.receiverMu.Unlock()
|
||||
|
||||
receiver, err := h.newReceiver(partitionID, opts...)
|
||||
receiver, err := h.newReceiver(ctx, partitionID, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -146,7 +187,7 @@ func (h *hub) Receive(partitionID string, handler Handler, opts ...ReceiveOption
|
|||
|
||||
// Send sends an AMQP message to the broker
|
||||
func (h *hub) Send(ctx context.Context, message *amqp.Message, opts ...SendOption) error {
|
||||
sender, err := h.getSender()
|
||||
sender, err := h.getSender(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -168,7 +209,7 @@ func HubWithPartitionedSender(partitionID string) HubOption {
|
|||
|
||||
// HubWithOffsetPersistence configures the hub instance to read and write offsets so that if a hub is interrupted, it
|
||||
// can resume after the last consumed event.
|
||||
func HubWithOffsetPersistence(offsetPersister OffsetPersister) HubOption {
|
||||
func HubWithOffsetPersistence(offsetPersister persist.OffsetPersister) HubOption {
|
||||
return func(h *hub) error {
|
||||
h.offsetPersister = offsetPersister
|
||||
return nil
|
||||
|
@ -195,12 +236,12 @@ func (h *hub) appendAgent(userAgent string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h *hub) getSender() (*sender, error) {
|
||||
func (h *hub) getSender(ctx context.Context) (*sender, error) {
|
||||
h.senderMu.Lock()
|
||||
defer h.senderMu.Unlock()
|
||||
|
||||
if h.sender == nil {
|
||||
s, err := h.newSender()
|
||||
s, err := h.newSender(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
112
hub_test.go
112
hub_test.go
|
@ -3,28 +3,59 @@ package eventhub
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/Azure/azure-event-hubs-go/aad"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"math/rand"
|
||||
"pack.ag/amqp"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/aad"
|
||||
"github.com/Azure/azure-event-hubs-go/sas"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"pack.ag/amqp"
|
||||
)
|
||||
|
||||
func (suite *eventHubSuite) TestSasToken() {
|
||||
tests := map[string]func(*testing.T, Client, []string, string){
|
||||
"TestMultiSendAndReceive": testMultiSendAndReceive,
|
||||
"TestHubRuntimeInformation": testHubRuntimeInformation,
|
||||
"TestHubPartitionRuntimeInformation": testHubPartitionRuntimeInformation,
|
||||
}
|
||||
|
||||
for name, testFunc := range tests {
|
||||
setupTestTeardown := func(t *testing.T) {
|
||||
hubName := randomName("goehtest", 10)
|
||||
mgmtHub, err := suite.ensureEventHub(context.Background(), hubName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer suite.deleteEventHub(context.Background(), hubName)
|
||||
provider, err := sas.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client, err := NewClient(suite.namespace, hubName, provider)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
|
||||
if err := client.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
suite.T().Run(name, setupTestTeardown)
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) TestPartitionedSender() {
|
||||
tests := map[string]func(*testing.T, Client, string){
|
||||
"TestSend": testBasicSend,
|
||||
"TestSendAndReceive": testBasicSendAndReceive,
|
||||
}
|
||||
|
||||
token, err := suite.getEventHubsTokenProvider()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
token.Refresh()
|
||||
|
||||
for name, testFunc := range tests {
|
||||
setupTestTeardown := func(t *testing.T) {
|
||||
hubName := randomName("goehtest", 10)
|
||||
|
@ -34,7 +65,10 @@ func (suite *eventHubSuite) TestPartitionedSender() {
|
|||
}
|
||||
defer suite.deleteEventHub(context.Background(), hubName)
|
||||
partitionID := (*mgmtHub.PartitionIds)[0]
|
||||
provider := aad.NewProvider(token)
|
||||
provider, err := aad.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client, err := NewClient(suite.namespace, hubName, provider, HubWithPartitionedSender(partitionID))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -51,7 +85,7 @@ func (suite *eventHubSuite) TestPartitionedSender() {
|
|||
}
|
||||
|
||||
func testBasicSend(t *testing.T, client Client, _ string) {
|
||||
err := client.Send(context.Background(), amqp.NewMessage([]byte("Hello!")) )
|
||||
err := client.Send(context.Background(), amqp.NewMessage([]byte("Hello!")))
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
|
@ -75,12 +109,14 @@ func testBasicSendAndReceive(t *testing.T, client Client, partitionID string) {
|
|||
}
|
||||
|
||||
count := 0
|
||||
err := client.Receive(partitionID, func(ctx context.Context, msg *amqp.Message) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
err := client.Receive(ctx, partitionID, func(ctx context.Context, msg *amqp.Message) error {
|
||||
assert.Equal(t, messages[count], string(msg.Data[0]))
|
||||
count++
|
||||
wg.Done()
|
||||
return nil
|
||||
}, ReceiveWithPrefetchCount(100))
|
||||
cancel()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -88,16 +124,10 @@ func testBasicSendAndReceive(t *testing.T, client Client, partitionID string) {
|
|||
}
|
||||
|
||||
func (suite *eventHubSuite) TestMultiPartition() {
|
||||
tests := map[string]func(*testing.T, Client, []string){
|
||||
tests := map[string]func(*testing.T, Client, []string, string){
|
||||
"TestMultiSendAndReceive": testMultiSendAndReceive,
|
||||
}
|
||||
|
||||
token, err := suite.getEventHubsTokenProvider()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
token.Refresh()
|
||||
|
||||
for name, testFunc := range tests {
|
||||
setupTestTeardown := func(t *testing.T) {
|
||||
hubName := randomName("goehtest", 10)
|
||||
|
@ -106,13 +136,16 @@ func (suite *eventHubSuite) TestMultiPartition() {
|
|||
t.Fatal(err)
|
||||
}
|
||||
defer suite.deleteEventHub(context.Background(), hubName)
|
||||
provider := aad.NewProvider(token)
|
||||
provider, err := aad.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client, err := NewClient(suite.namespace, hubName, provider)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
testFunc(t, client, *mgmtHub.PartitionIds)
|
||||
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
|
||||
if err := client.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -122,7 +155,7 @@ func (suite *eventHubSuite) TestMultiPartition() {
|
|||
}
|
||||
}
|
||||
|
||||
func testMultiSendAndReceive(t *testing.T, client Client, partitionIDs []string) {
|
||||
func testMultiSendAndReceive(t *testing.T, client Client, partitionIDs []string, _ string) {
|
||||
numMessages := rand.Intn(100) + 20
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numMessages)
|
||||
|
@ -141,8 +174,9 @@ func testMultiSendAndReceive(t *testing.T, client Client, partitionIDs []string)
|
|||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
for _, partitionID := range partitionIDs {
|
||||
err := client.Receive(partitionID, func(ctx context.Context, msg *amqp.Message) error {
|
||||
err := client.Receive(ctx, partitionID, func(ctx context.Context, msg *amqp.Message) error {
|
||||
wg.Done()
|
||||
return nil
|
||||
}, ReceiveWithPrefetchCount(100))
|
||||
|
@ -150,6 +184,7 @@ func testMultiSendAndReceive(t *testing.T, client Client, partitionIDs []string)
|
|||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
|
@ -159,12 +194,6 @@ func (suite *eventHubSuite) TestHubManagement() {
|
|||
"TestHubPartitionRuntimeInformation": testHubPartitionRuntimeInformation,
|
||||
}
|
||||
|
||||
token, err := suite.getEventHubsTokenProvider()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
token.Refresh()
|
||||
|
||||
for name, testFunc := range tests {
|
||||
setupTestTeardown := func(t *testing.T) {
|
||||
hubName := randomName("goehtest", 10)
|
||||
|
@ -173,7 +202,10 @@ func (suite *eventHubSuite) TestHubManagement() {
|
|||
t.Fatal(err)
|
||||
}
|
||||
defer suite.deleteEventHub(context.Background(), hubName)
|
||||
provider := aad.NewProvider(token)
|
||||
provider, err := aad.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client, err := NewClient(suite.namespace, hubName, provider)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -208,6 +240,13 @@ func testHubPartitionRuntimeInformation(t *testing.T, client Client, partitionID
|
|||
assert.Equal(t, "-1", info.LastEnqueuedOffset) // brand new, so should be very last
|
||||
}
|
||||
|
||||
func TestEnvironmentalCreation(t *testing.T) {
|
||||
os.Setenv("EVENTHUB_NAME", "foo")
|
||||
_, err := NewClientFromEnvironment()
|
||||
assert.Nil(t, err)
|
||||
os.Unsetenv("EVENTHUB_NAME")
|
||||
}
|
||||
|
||||
func BenchmarkReceive(b *testing.B) {
|
||||
suite := new(eventHubSuite)
|
||||
suite.SetupSuite()
|
||||
|
@ -225,8 +264,10 @@ func BenchmarkReceive(b *testing.B) {
|
|||
messages[i] = randomName("hello", 10)
|
||||
}
|
||||
|
||||
token, err := suite.getEventHubsTokenProvider()
|
||||
provider := aad.NewProvider(token)
|
||||
provider, err := aad.NewProviderFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
hub, err := NewClient(suite.namespace, *mgmtHub.Name, provider)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
|
@ -248,9 +289,10 @@ func BenchmarkReceive(b *testing.B) {
|
|||
|
||||
b.ResetTimer()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
// receive from all partition IDs
|
||||
for _, partitionID := range *mgmtHub.PartitionIds {
|
||||
err = hub.Receive(partitionID, func(ctx context.Context, msg *amqp.Message) error {
|
||||
err = hub.Receive(ctx, partitionID, func(ctx context.Context, msg *amqp.Message) error {
|
||||
wg.Done()
|
||||
return nil
|
||||
}, ReceiveWithPrefetchCount(100))
|
||||
|
@ -258,7 +300,7 @@ func BenchmarkReceive(b *testing.B) {
|
|||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
b.StopTimer()
|
||||
}
|
||||
|
|
20
mgmt/mgmt.go
20
mgmt/mgmt.go
|
@ -3,12 +3,13 @@ package mgmt
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"github.com/Azure/azure-event-hubs-go/rpc"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/pkg/errors"
|
||||
"pack.ag/amqp"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -103,7 +104,7 @@ func (c *Client) GetHubRuntimeInformation(ctx context.Context, conn *amqp.Client
|
|||
entityNameKey: c.hubName,
|
||||
},
|
||||
}
|
||||
err = c.addSecurityToken(msg)
|
||||
msg, err = c.addSecurityToken(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -135,7 +136,7 @@ func (c *Client) GetHubPartitionRuntimeInformation(ctx context.Context, conn *am
|
|||
partitionNameKey: partitionID,
|
||||
},
|
||||
}
|
||||
err = c.addSecurityToken(msg)
|
||||
msg, err = c.addSecurityToken(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -152,13 +153,16 @@ func (c *Client) GetHubPartitionRuntimeInformation(ctx context.Context, conn *am
|
|||
return hubPartitionRuntimeInfo, nil
|
||||
}
|
||||
|
||||
func (c *Client) addSecurityToken(msg *amqp.Message) error {
|
||||
func (c *Client) addSecurityToken(msg *amqp.Message) (*amqp.Message, error) {
|
||||
// TODO (devigned): need to uncomment this functionality after getting some guidance from the Event Hubs team (only works for SAS tokens right now)
|
||||
|
||||
//token, err := c.tokenProvider.GetToken(c.getTokenAudience())
|
||||
//if err != nil {
|
||||
// return nil
|
||||
// return nil, err
|
||||
//}
|
||||
//msg.ApplicationProperties[securityTokenKey] = token.Token
|
||||
return nil
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (c *Client) getTokenAudience() string {
|
||||
|
@ -228,10 +232,6 @@ func newHubRuntimeInformation(msg *amqp.Message) (*HubRuntimeInformation, error)
|
|||
}
|
||||
|
||||
if createdAt, ok := values[resultCreatedAtKey].(time.Time); ok {
|
||||
//t, err := time.Parse("UnixDate", createdAt)
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
info.CreatedAt = createdAt
|
||||
} else {
|
||||
return nil, errors.Errorf(errMsgFmt, resultCreatedAtKey, values)
|
||||
|
|
16
namespace.go
16
namespace.go
|
@ -1,14 +1,15 @@
|
|||
package eventhub
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"github.com/Azure/azure-event-hubs-go/cbs"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"pack.ag/amqp"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -48,7 +49,8 @@ func (ns *namespace) connection() (*amqp.Client, error) {
|
|||
amqp.ConnProperty("version", "0.0.1"),
|
||||
amqp.ConnProperty("platform", runtime.GOOS),
|
||||
amqp.ConnProperty("framework", runtime.Version()),
|
||||
amqp.ConnProperty("user-agent", rootUserAgent))
|
||||
amqp.ConnProperty("user-agent", rootUserAgent),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -57,17 +59,17 @@ func (ns *namespace) connection() (*amqp.Client, error) {
|
|||
return ns.client, nil
|
||||
}
|
||||
|
||||
func (ns *namespace) negotiateClaim(entityPath string) error {
|
||||
func (ns *namespace) negotiateClaim(ctx context.Context, entityPath string) error {
|
||||
audience := ns.getEntityAudience(entityPath)
|
||||
conn, err := ns.connection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cbs.NegotiateClaim(audience, conn, ns.tokenProvider)
|
||||
return cbs.NegotiateClaim(ctx, audience, conn, ns.tokenProvider)
|
||||
}
|
||||
|
||||
func (ns *namespace) getAmqpHostURI() string {
|
||||
return fmt.Sprintf("amqps://%s.%s/", ns.name, ns.environment.ServiceBusEndpointSuffix)
|
||||
return "amqps://" + ns.name + "." + ns.environment.ServiceBusEndpointSuffix + "/"
|
||||
}
|
||||
|
||||
func (ns *namespace) getEntityAudience(entityPath string) string {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package eventhub
|
||||
package persist
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
@ -7,6 +7,13 @@ import (
|
|||
)
|
||||
|
||||
type (
|
||||
// OffsetPersister provides persistence for the received offset for a given namespace, hub name, consumer group, partition Id and
|
||||
// offset so that if a receiver where to be interrupted, it could resume after the last consumed event.
|
||||
OffsetPersister interface {
|
||||
Write(namespace, name, consumerGroup, partitionID, offset string) error
|
||||
Read(namespace, name, consumerGroup, partitionID string) (string, error)
|
||||
}
|
||||
|
||||
// MemoryPersister is a default implementation of a Hub OffsetPersister, which will persist offset information in
|
||||
// memory.
|
||||
MemoryPersister struct {
|
107
receiver.go
107
receiver.go
|
@ -3,11 +3,11 @@ package eventhub
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net"
|
||||
"pack.ag/amqp"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/persist"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"pack.ag/amqp"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -36,7 +36,7 @@ type (
|
|||
consumerGroup string
|
||||
partitionID string
|
||||
prefetchCount uint32
|
||||
done chan struct{}
|
||||
done func()
|
||||
lastReceivedOffset atomic.Value
|
||||
}
|
||||
|
||||
|
@ -53,8 +53,6 @@ func ReceiveWithConsumerGroup(consumerGroup string) ReceiveOption {
|
|||
}
|
||||
|
||||
// ReceiveWithStartingOffset configures the receiver to start at a given position in the event stream
|
||||
//
|
||||
// This setting will be overridden by the Hub's OffsetPersister if an offset can be read.
|
||||
func ReceiveWithStartingOffset(offset string) ReceiveOption {
|
||||
return func(receiver *receiver) error {
|
||||
receiver.storeLastReceivedOffset(offset)
|
||||
|
@ -62,6 +60,14 @@ func ReceiveWithStartingOffset(offset string) ReceiveOption {
|
|||
}
|
||||
}
|
||||
|
||||
// ReceiveWithLatestOffset configures the receiver to start at a given position in the event stream
|
||||
func ReceiveWithLatestOffset() ReceiveOption {
|
||||
return func(receiver *receiver) error {
|
||||
receiver.storeLastReceivedOffset(EndOfStream)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReceiveWithPrefetchCount configures the receiver to attempt to fetch as many messages as the prefetch amount
|
||||
func ReceiveWithPrefetchCount(prefetch uint32) ReceiveOption {
|
||||
return func(receiver *receiver) error {
|
||||
|
@ -71,13 +77,12 @@ func ReceiveWithPrefetchCount(prefetch uint32) ReceiveOption {
|
|||
}
|
||||
|
||||
// newReceiver creates a new Service Bus message listener given an AMQP client and an entity path
|
||||
func (h *hub) newReceiver(partitionID string, opts ...ReceiveOption) (*receiver, error) {
|
||||
func (h *hub) newReceiver(ctx context.Context, partitionID string, opts ...ReceiveOption) (*receiver, error) {
|
||||
receiver := &receiver{
|
||||
hub: h,
|
||||
consumerGroup: DefaultConsumerGroup,
|
||||
prefetchCount: defaultPrefetchCount,
|
||||
partitionID: partitionID,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
|
@ -87,38 +92,33 @@ func (h *hub) newReceiver(partitionID string, opts ...ReceiveOption) (*receiver,
|
|||
}
|
||||
|
||||
log.Debugf("creating a new receiver for entity path %s", receiver.getAddress())
|
||||
err := receiver.newSessionAndLink()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return receiver, nil
|
||||
err := receiver.newSessionAndLink(ctx)
|
||||
return receiver, err
|
||||
}
|
||||
|
||||
// Close will close the AMQP session and link of the receiver
|
||||
func (r *receiver) Close() error {
|
||||
close(r.done)
|
||||
if r.done != nil {
|
||||
r.done()
|
||||
}
|
||||
|
||||
err := r.receiver.Close()
|
||||
if err != nil {
|
||||
_ = r.session.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.session.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return r.session.Close()
|
||||
}
|
||||
|
||||
// Recover will attempt to close the current session and link, then rebuild them
|
||||
func (r *receiver) Recover() error {
|
||||
func (r *receiver) Recover(ctx context.Context) error {
|
||||
err := r.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.newSessionAndLink()
|
||||
err = r.newSessionAndLink(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -128,20 +128,20 @@ func (r *receiver) Recover() error {
|
|||
|
||||
// Listen start a listener for messages sent to the entity path
|
||||
func (r *receiver) Listen(handler Handler) {
|
||||
ctx, done := context.WithCancel(context.Background())
|
||||
r.done = done
|
||||
messages := make(chan *amqp.Message)
|
||||
go r.listenForMessages(messages)
|
||||
go r.handleMessages(messages, handler)
|
||||
go r.listenForMessages(ctx, messages)
|
||||
go r.handleMessages(ctx, messages, handler)
|
||||
}
|
||||
|
||||
func (r *receiver) handleMessages(messages chan *amqp.Message, handler Handler) {
|
||||
func (r *receiver) handleMessages(ctx context.Context, messages chan *amqp.Message, handler Handler) {
|
||||
for {
|
||||
select {
|
||||
case <-r.done:
|
||||
case <-ctx.Done():
|
||||
log.Debug("done handling messages")
|
||||
close(messages)
|
||||
return
|
||||
case msg := <-messages:
|
||||
ctx := context.Background()
|
||||
id := messageID(msg)
|
||||
log.Debugf("message id: %v is being passed to handler", id)
|
||||
|
||||
|
@ -149,48 +149,39 @@ func (r *receiver) handleMessages(messages chan *amqp.Message, handler Handler)
|
|||
if err != nil {
|
||||
msg.Reject()
|
||||
log.Debugf("message rejected: id: %v", id)
|
||||
} else {
|
||||
// Accept message
|
||||
msg.Accept()
|
||||
log.Debugf("message accepted: id: %v", id)
|
||||
continue
|
||||
}
|
||||
msg.Accept()
|
||||
log.Debugf("message accepted: id: %v", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *receiver) listenForMessages(msgChan chan *amqp.Message) {
|
||||
func (r *receiver) listenForMessages(ctx context.Context, msgChan chan *amqp.Message) {
|
||||
for {
|
||||
select {
|
||||
case <-r.done:
|
||||
log.Debug("done listening for messages")
|
||||
return
|
||||
default:
|
||||
waitCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
msg, err := r.receiver.Receive(waitCtx)
|
||||
cancel()
|
||||
|
||||
if err == amqp.ErrLinkClosed || err == amqp.ErrSessionClosed {
|
||||
log.Debug("done listening for messages due to link or session closed")
|
||||
time.Sleep(10 * time.Second)
|
||||
return
|
||||
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
log.Debug("attempting to receive messages timed out")
|
||||
continue
|
||||
} else if err != nil {
|
||||
log.Error(err)
|
||||
msg, err := r.receiver.Receive(ctx)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
r.receivedMessage(msg)
|
||||
msgChan <- msg
|
||||
id := messageID(msg)
|
||||
log.Debugf("Message received: %s", id)
|
||||
select {
|
||||
case msgChan <- msg:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newSessionAndLink will replace the session and link on the receiver
|
||||
func (r *receiver) newSessionAndLink() error {
|
||||
func (r *receiver) newSessionAndLink(ctx context.Context) error {
|
||||
address := r.getAddress()
|
||||
err := r.hub.namespace.negotiateClaim(address)
|
||||
err := r.hub.namespace.negotiateClaim(ctx, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -258,7 +249,7 @@ func (r *receiver) hubName() string {
|
|||
return r.hub.name
|
||||
}
|
||||
|
||||
func (r *receiver) offsetPersister() OffsetPersister {
|
||||
func (r *receiver) offsetPersister() persist.OffsetPersister {
|
||||
return r.hub.offsetPersister
|
||||
}
|
||||
|
||||
|
@ -280,7 +271,7 @@ func (r *receiver) receivedMessage(msg *amqp.Message) {
|
|||
}
|
||||
|
||||
func messageID(msg *amqp.Message) interface{} {
|
||||
id := interface{}("null")
|
||||
var id interface{} = "null"
|
||||
if msg.Properties != nil {
|
||||
id = msg.Properties.MessageID
|
||||
}
|
||||
|
|
82
rpc/rpc.go
82
rpc/rpc.go
|
@ -3,14 +3,15 @@ package rpc
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/common"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/satori/go.uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"pack.ag/amqp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -56,7 +57,8 @@ func NewLink(conn *amqp.Client, address string) (*Link, error) {
|
|||
clientAddress := strings.Replace("$", "", address, -1) + replyPostfix + id
|
||||
authReceiver, err := authSession.NewReceiver(
|
||||
amqp.LinkSourceAddress(address),
|
||||
amqp.LinkTargetAddress(clientAddress))
|
||||
amqp.LinkTargetAddress(clientAddress),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -79,26 +81,19 @@ func (l *Link) RetryableRPC(ctx context.Context, times int, delay time.Duration,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if res.ServerError() {
|
||||
switch {
|
||||
case res.Code >= 200 && res.Code < 300:
|
||||
log.Debugf("successful rpc on link %s: status code %d and description: %s", l.id, res.Code, res.Description)
|
||||
return res, nil
|
||||
case res.Code >= 500:
|
||||
errMessage := fmt.Sprintf("server error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
|
||||
log.Debugln(errMessage)
|
||||
return nil, &common.Retryable{Message: errMessage}
|
||||
}
|
||||
|
||||
if res.ClientError() {
|
||||
errMessage := fmt.Sprintf("client error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
|
||||
default:
|
||||
errMessage := fmt.Sprintf("unhandled error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
|
||||
log.Debugln(errMessage)
|
||||
return nil, errors.New(errMessage)
|
||||
return nil, &common.Retryable{Message: errMessage}
|
||||
}
|
||||
|
||||
if res.Success() {
|
||||
log.Debugf("successful rpc on link %s: status code %d and description: %s", l.id, res.Code, res.Description)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
errMessage := fmt.Sprintf("unhandled error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
|
||||
log.Debugln(errMessage)
|
||||
return nil, &common.Retryable{Message: errMessage}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -145,32 +140,39 @@ func (l *Link) RPC(ctx context.Context, msg *amqp.Message) (*Response, error) {
|
|||
return response, err
|
||||
}
|
||||
|
||||
// Close the link sender, receiver and session
|
||||
func (l *Link) Close() {
|
||||
if l.sender != nil {
|
||||
l.sender.Close()
|
||||
// Close the link receiver, sender and session
|
||||
func (l *Link) Close() error {
|
||||
if err := l.closeReceiver(); err != nil {
|
||||
_ = l.closeSender()
|
||||
_ = l.closeSession()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := l.closeSender(); err != nil {
|
||||
_ = l.closeSession()
|
||||
return err
|
||||
}
|
||||
|
||||
return l.closeSession()
|
||||
}
|
||||
|
||||
func (l *Link) closeReceiver() error {
|
||||
if l.receiver != nil {
|
||||
l.receiver.Close()
|
||||
return l.receiver.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Link) closeSender() error {
|
||||
if l.sender != nil {
|
||||
return l.sender.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Link) closeSession() error {
|
||||
if l.session != nil {
|
||||
l.session.Close()
|
||||
return l.session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Success return true if the status code is between 200 and 300
|
||||
func (r *Response) Success() bool {
|
||||
return r.Code >= 200 && r.Code < 300
|
||||
}
|
||||
|
||||
// ServerError is true when status code is 500 or greater
|
||||
func (r *Response) ServerError() bool {
|
||||
return r.Code >= 500
|
||||
}
|
||||
|
||||
// ClientError is true when status code is in the 400s
|
||||
func (r *Response) ClientError() bool {
|
||||
return r.Code >= 400 && r.Code < 500
|
||||
return nil
|
||||
}
|
||||
|
|
42
sas/sas.go
42
sas/sas.go
|
@ -5,11 +5,15 @@ import (
|
|||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-event-hubs-go/auth"
|
||||
"github.com/Azure/azure-event-hubs-go/common"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -33,6 +37,36 @@ func NewProvider(namespace, keyName, key string) auth.TokenProvider {
|
|||
}
|
||||
}
|
||||
|
||||
// NewProviderFromEnvironment creates a new SAS TokenProvider from environment variables
|
||||
//
|
||||
// Expected Environment Variables:
|
||||
// - "EVENTHUB_NAMESPACE" the namespace of the Event Hub instance
|
||||
// - "EVENTHUB_KEY_NAME" the name of the Event Hub key
|
||||
// - "EVENTHUB_KEY_VALUE" the secret for the Event Hub key named in "EVENTHUB_KEY_NAME"
|
||||
func NewProviderFromEnvironment() (auth.TokenProvider, error) {
|
||||
var provider auth.TokenProvider
|
||||
keyName := os.Getenv("EVENTHUB_KEY_NAME")
|
||||
keyValue := os.Getenv("EVENTHUB_KEY_VALUE")
|
||||
namespace := os.Getenv("EVENTHUB_NAMESPACE")
|
||||
connStr := os.Getenv("EVENTHUB_CONNECTION_STRING")
|
||||
|
||||
if (keyName == "" || keyValue == "" || namespace == "") && connStr == "" {
|
||||
return nil, errors.New("unable to build SAS token provider because (EVENTHUB_KEY_NAME, EVENTHUB_KEY_VALUE and EVENTHUB_NAMESPACE) were empty, and EVENTHUB_CONNECTION_STRING was empty")
|
||||
}
|
||||
|
||||
if connStr != "" {
|
||||
parsed, err := common.ParsedConnectionFromStr(connStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider = NewProvider(parsed.Namespace, parsed.KeyName, parsed.Key)
|
||||
} else {
|
||||
provider = NewProvider(namespace, keyName, keyValue)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// GetToken gets a CBS SAS token
|
||||
func (t *TokenProvider) GetToken(audience string) (*auth.Token, error) {
|
||||
signed, expiry := t.signer.SignWithDuration(audience, 2*time.Hour)
|
||||
|
@ -56,10 +90,10 @@ func (s *Signer) SignWithDuration(uri string, interval time.Duration) (signed, e
|
|||
|
||||
// SignWithExpiry signs a given uri with a given expiry string
|
||||
func (s *Signer) SignWithExpiry(uri, expiry string) string {
|
||||
u := strings.ToLower(url.QueryEscape(uri))
|
||||
sts := stringToSign(u, expiry)
|
||||
audience := strings.ToLower(url.QueryEscape(uri))
|
||||
sts := stringToSign(audience, expiry)
|
||||
sig := s.signString(sts)
|
||||
return fmt.Sprintf("SharedAccessSignature sig=%s&se=%s&skn=%s&sr=%s", sig, expiry, s.keyName, u)
|
||||
return fmt.Sprintf("SharedAccessSignature sr=%s&sig=%s&se=%s&skn=%s", audience, sig, expiry, s.keyName)
|
||||
}
|
||||
|
||||
func signatureExpiry(from time.Time, interval time.Duration) string {
|
||||
|
|
40
sender.go
40
sender.go
|
@ -3,6 +3,7 @@ package eventhub
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"pack.ag/amqp"
|
||||
)
|
||||
|
@ -22,48 +23,33 @@ type (
|
|||
)
|
||||
|
||||
// newSender creates a new Service Bus message sender given an AMQP client and entity path
|
||||
func (h *hub) newSender() (*sender, error) {
|
||||
func (h *hub) newSender(ctx context.Context) (*sender, error) {
|
||||
s := &sender{
|
||||
hub: h,
|
||||
partitionID: h.senderPartitionID,
|
||||
}
|
||||
|
||||
log.Debugf("creating a new sender for entity path %s", s.getAddress())
|
||||
err := s.newSessionAndLink()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
err := s.newSessionAndLink(ctx)
|
||||
return s, err
|
||||
}
|
||||
|
||||
// Recover will attempt to close the current session and link, then rebuild them
|
||||
func (s *sender) Recover() error {
|
||||
func (s *sender) Recover(ctx context.Context) error {
|
||||
err := s.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.newSessionAndLink()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return s.newSessionAndLink(ctx)
|
||||
}
|
||||
|
||||
// Close will close the AMQP session and link of the sender
|
||||
func (s *sender) Close() error {
|
||||
err := s.sender.Close()
|
||||
if err != nil {
|
||||
_ = s.session.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.session.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.session.Close()
|
||||
}
|
||||
|
||||
// Send will send a message to the entity path with options
|
||||
|
@ -82,11 +68,7 @@ func (s *sender) Send(ctx context.Context, msg *amqp.Message, opts ...SendOption
|
|||
msg.Annotations["x-opt-partition-key"] = s.partitionID
|
||||
}
|
||||
|
||||
err := s.sender.Send(ctx, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return s.sender.Send(ctx, msg)
|
||||
}
|
||||
|
||||
//func (s *sender) SendBatch(ctx context.Context, messages []*amqp.Message) error {
|
||||
|
@ -115,8 +97,8 @@ func (s *sender) prepareMessage(msg *amqp.Message) {
|
|||
}
|
||||
|
||||
// newSessionAndLink will replace the existing session and link
|
||||
func (s *sender) newSessionAndLink() error {
|
||||
err := s.hub.namespace.negotiateClaim(s.getAddress())
|
||||
func (s *sender) newSessionAndLink(ctx context.Context) error {
|
||||
err := s.hub.namespace.negotiateClaim(ctx, s.getAddress())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
package eventhub
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/satori/go.uuid"
|
||||
"pack.ag/amqp"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type (
|
||||
|
@ -20,7 +21,6 @@ func newSession(amqpSession *amqp.Session) *session {
|
|||
return &session{
|
||||
Session: amqpSession,
|
||||
SessionID: uuid.NewV4().String(),
|
||||
counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
131
suite_test.go
131
suite_test.go
|
@ -4,18 +4,18 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
|
||||
rm "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2017-05-10/resources"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
|
||||
rm "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2017-05-10/resources"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
azauth "github.com/Azure/go-autorest/autorest/azure/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -32,19 +32,9 @@ type (
|
|||
// eventHubSuite encapsulates a end to end test of Event Hubs with build up and tear down of all EH resources
|
||||
eventHubSuite struct {
|
||||
suite.Suite
|
||||
tenantID string
|
||||
subscriptionID string
|
||||
clientID string
|
||||
clientSecret string
|
||||
namespace string
|
||||
env azure.Environment
|
||||
armToken *adal.ServicePrincipalToken
|
||||
}
|
||||
|
||||
servicePrincipalCredentials struct {
|
||||
TenantID string
|
||||
ApplicationID string
|
||||
Secret string
|
||||
}
|
||||
|
||||
// HubMgmtOption represents an option for configuring an Event Hub.
|
||||
|
@ -67,13 +57,20 @@ func (suite *eventHubSuite) SetupSuite() {
|
|||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
suite.tenantID = mustGetEnv("AZURE_TENANT_ID")
|
||||
suite.subscriptionID = mustGetEnv("AZURE_SUBSCRIPTION_ID")
|
||||
suite.clientID = mustGetEnv("AZURE_CLIENT_ID")
|
||||
suite.clientSecret = mustGetEnv("AZURE_CLIENT_SECRET")
|
||||
suite.namespace = mustGetEnv("EVENTHUB_NAMESPACE")
|
||||
suite.env = azure.PublicCloud
|
||||
suite.armToken = suite.servicePrincipalToken()
|
||||
envName := os.Getenv("AZURE_ENVIRONMENT")
|
||||
|
||||
if envName == "" {
|
||||
suite.env = azure.PublicCloud
|
||||
} else {
|
||||
var err error
|
||||
env, err := azure.EnvironmentFromName(envName)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
suite.env = env
|
||||
}
|
||||
|
||||
err := suite.ensureProvisioned(mgmt.SkuTierStandard)
|
||||
if err != nil {
|
||||
|
@ -86,7 +83,7 @@ func (suite *eventHubSuite) TearDownSuite() {
|
|||
}
|
||||
|
||||
func (suite *eventHubSuite) ensureProvisioned(tier mgmt.SkuTier) error {
|
||||
_, err := ensureResourceGroup(context.Background(), suite.subscriptionID, ResourceGroupName, Location, suite.armToken, suite.env)
|
||||
_, err := ensureResourceGroup(context.Background(), suite.subscriptionID, ResourceGroupName, Location, suite.env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -99,27 +96,9 @@ func (suite *eventHubSuite) ensureProvisioned(tier mgmt.SkuTier) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) servicePrincipalToken() *adal.ServicePrincipalToken {
|
||||
|
||||
oauthConfig, err := adal.NewOAuthConfig(suite.env.ActiveDirectoryEndpoint, suite.tenantID)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
tokenProvider, err := adal.NewServicePrincipalToken(*oauthConfig,
|
||||
suite.clientID,
|
||||
suite.clientSecret,
|
||||
suite.env.ResourceManagerEndpoint)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
return tokenProvider
|
||||
}
|
||||
|
||||
// ensureResourceGroup creates a Azure Resource Group if it does not already exist
|
||||
func ensureResourceGroup(ctx context.Context, subscriptionID, name, location string, armToken *adal.ServicePrincipalToken, env azure.Environment) (*rm.Group, error) {
|
||||
groupClient := getRmGroupClientWithToken(subscriptionID, armToken, env)
|
||||
func ensureResourceGroup(ctx context.Context, subscriptionID, name, location string, env azure.Environment) (*rm.Group, error) {
|
||||
groupClient := getRmGroupClientWithToken(subscriptionID, env)
|
||||
group, err := groupClient.Get(ctx, name)
|
||||
|
||||
if group.StatusCode == http.StatusNotFound {
|
||||
|
@ -135,13 +114,13 @@ func ensureResourceGroup(ctx context.Context, subscriptionID, name, location str
|
|||
}
|
||||
|
||||
// ensureNamespace creates a Azure Event Hub Namespace if it does not already exist
|
||||
func ensureNamespace(ctx context.Context, subscriptionID, rg, name, location string, armToken *adal.ServicePrincipalToken, env azure.Environment, opts ...namespaceMgmtOption) (*mgmt.EHNamespace, error) {
|
||||
_, err := ensureResourceGroup(ctx, subscriptionID, rg, location, armToken, env)
|
||||
func ensureNamespace(ctx context.Context, subscriptionID, rg, name, location string, env azure.Environment, opts ...namespaceMgmtOption) (*mgmt.EHNamespace, error) {
|
||||
_, err := ensureResourceGroup(ctx, subscriptionID, rg, location, env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := getNamespaceMgmtClientWithToken(subscriptionID, armToken, env)
|
||||
client := getNamespaceMgmtClientWithToken(subscriptionID, env)
|
||||
namespace, err := client.Get(ctx, rg, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -237,38 +216,54 @@ func (suite *eventHubSuite) deleteEventHub(ctx context.Context, name string) err
|
|||
|
||||
func (suite *eventHubSuite) getEventHubMgmtClient() *mgmt.EventHubsClient {
|
||||
client := mgmt.NewEventHubsClientWithBaseURI(suite.env.ResourceManagerEndpoint, suite.subscriptionID)
|
||||
client.Authorizer = autorest.NewBearerAuthorizer(suite.armToken)
|
||||
a, err := azauth.NewAuthorizerFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
client.Authorizer = a
|
||||
return &client
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) getNamespaceMgmtClient() *mgmt.NamespacesClient {
|
||||
return getNamespaceMgmtClientWithToken(suite.subscriptionID, suite.armToken, suite.env)
|
||||
return getNamespaceMgmtClientWithToken(suite.subscriptionID, suite.env)
|
||||
}
|
||||
|
||||
func getNamespaceMgmtClientWithToken(subscriptionID string, armToken *adal.ServicePrincipalToken, env azure.Environment) *mgmt.NamespacesClient {
|
||||
func getNamespaceMgmtClientWithToken(subscriptionID string, env azure.Environment) *mgmt.NamespacesClient {
|
||||
client := mgmt.NewNamespacesClientWithBaseURI(env.ResourceManagerEndpoint, subscriptionID)
|
||||
client.Authorizer = autorest.NewBearerAuthorizer(armToken)
|
||||
a, err := azauth.NewAuthorizerFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
client.Authorizer = a
|
||||
return &client
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) getNamespaceMgmtClientWithCredentials(ctx context.Context, subscriptionID, rg, name string) *mgmt.NamespacesClient {
|
||||
client := mgmt.NewNamespacesClientWithBaseURI(suite.env.ResourceManagerEndpoint, suite.subscriptionID)
|
||||
client.Authorizer = autorest.NewBearerAuthorizer(suite.armToken)
|
||||
a, err := azauth.NewAuthorizerFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
client.Authorizer = a
|
||||
return &client
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) getRmGroupClient() *rm.GroupsClient {
|
||||
return getRmGroupClientWithToken(suite.subscriptionID, suite.armToken, suite.env)
|
||||
return getRmGroupClientWithToken(suite.subscriptionID, suite.env)
|
||||
}
|
||||
|
||||
func getRmGroupClientWithToken(subscriptionID string, armToken *adal.ServicePrincipalToken, env azure.Environment) *rm.GroupsClient {
|
||||
func getRmGroupClientWithToken(subscriptionID string, env azure.Environment) *rm.GroupsClient {
|
||||
groupsClient := rm.NewGroupsClientWithBaseURI(env.ResourceManagerEndpoint, subscriptionID)
|
||||
groupsClient.Authorizer = autorest.NewBearerAuthorizer(armToken)
|
||||
a, err := azauth.NewAuthorizerFromEnvironment()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
groupsClient.Authorizer = a
|
||||
return &groupsClient
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) ensureResourceGroup() (*rm.Group, error) {
|
||||
group, err := ensureResourceGroup(context.Background(), suite.subscriptionID, suite.namespace, Location, suite.armToken, suite.env)
|
||||
group, err := ensureResourceGroup(context.Background(), suite.subscriptionID, suite.namespace, Location, suite.env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -276,37 +271,13 @@ func (suite *eventHubSuite) ensureResourceGroup() (*rm.Group, error) {
|
|||
}
|
||||
|
||||
func (suite *eventHubSuite) ensureNamespace() (*mgmt.EHNamespace, error) {
|
||||
ns, err := ensureNamespace(context.Background(), suite.subscriptionID, ResourceGroupName, suite.namespace, Location, suite.armToken, suite.env)
|
||||
ns, err := ensureNamespace(context.Background(), suite.subscriptionID, ResourceGroupName, suite.namespace, Location, suite.env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ns, err
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) getEventHubsTokenProvider() (*adal.ServicePrincipalToken, error) {
|
||||
// TODO: fix the azure environment var for the SB endpoint and EH endpoint
|
||||
return suite.getTokenProvider("https://eventhubs.azure.net/")
|
||||
}
|
||||
|
||||
func (suite *eventHubSuite) getTokenProvider(resourceURI string) (*adal.ServicePrincipalToken, error) {
|
||||
oauthConfig, err := adal.NewOAuthConfig(suite.env.ActiveDirectoryEndpoint, suite.tenantID)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
tokenProvider, err := adal.NewServicePrincipalToken(*oauthConfig, suite.clientID, suite.clientSecret, resourceURI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tokenProvider.Refresh()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokenProvider, nil
|
||||
}
|
||||
|
||||
func mustGetEnv(key string) string {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
|
|
Загрузка…
Ссылка в новой задаче