Merge pull request #9 from devigned/cleanup

Clean up and add some environmental help
This commit is contained in:
David Justice 2018-02-20 13:32:38 -08:00 коммит произвёл GitHub
Родитель bbbb427294 ac58da63db
Коммит 1d3f2d42fa
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
19 изменённых файлов: 591 добавлений и 378 удалений

17
Gopkg.lock сгенерированный
Просмотреть файл

@ -17,6 +17,7 @@
"autorest",
"autorest/adal",
"autorest/azure",
"autorest/azure/auth",
"autorest/date",
"autorest/to",
"autorest/validation"
@ -36,6 +37,12 @@
revision = "dbeaa9332f19a944acb5736b4456cfcc02140e29"
version = "v3.1.0"
[[projects]]
branch = "master"
name = "github.com/dimchansky/utfbom"
packages = ["."]
revision = "6c6132ff69f0f6c088739067407b5d32c52e1d0f"
[[projects]]
name = "github.com/pkg/errors"
packages = ["."]
@ -73,8 +80,12 @@
[[projects]]
branch = "master"
name = "golang.org/x/crypto"
packages = ["ssh/terminal"]
revision = "650f4a345ab4e5b245a3034b110ebc7299e68186"
packages = [
"pkcs12",
"pkcs12/internal/rc2",
"ssh/terminal"
]
revision = "432090b8f568c018896cd8a0fb0345872bbac6ce"
[[projects]]
branch = "master"
@ -97,6 +108,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "5f18a64a7717270d06573eadd917fc9fb583eb9953b9811f04fb4bc04a0c0a7f"
inputs-digest = "b7c45c0ec32d34417f3a7f0cd4ad053b98b3bb6c56d189c9585a2e10c2f49a62"
solver-name = "gps-cdcl"
solver-version = 1

Просмотреть файл

@ -17,7 +17,7 @@ DEP = dep
V = 0
Q = $(if $(filter 1,$V),,@)
M = $(shell printf "\033[34;1m▶\033[0m")
TIMEOUT = 100
TIMEOUT = 200
.PHONY: all
all: fmt vendor lint vet | $(BASE) ; $(info $(M) building library) @ ## Build program

Просмотреть файл

@ -1,18 +1,18 @@
package main
import (
"github.com/Azure/azure-event-hubs-go"
"fmt"
"time"
"os"
"github.com/Azure/go-autorest/autorest/azure"
"log"
"context"
"pack.ag/amqp"
"fmt"
"log"
"os"
"time"
"github.com/Azure/azure-event-hubs-go"
"github.com/Azure/azure-event-hubs-go/aad"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest"
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
"github.com/Azure/go-autorest/autorest/azure"
azauth "github.com/Azure/go-autorest/autorest/azure/auth"
"pack.ag/amqp"
)
const (
@ -26,19 +26,21 @@ func main() {
exit := make(chan struct{})
handler := func(ctx context.Context, msg *amqp.Message) error {
text := string(msg.Data)
text := string(msg.Data[0])
if text == "exit\n" {
fmt.Println("Someone told me to exit!")
exit <- *new(struct{})
} else {
fmt.Println(string(msg.Data))
fmt.Println(string(msg.Data[0]))
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
for _, partitionID := range partitions {
hub.Receive(partitionID, handler)
hub.Receive(ctx, partitionID, handler)
}
cancel()
select {
case <-exit:
@ -57,11 +59,10 @@ func initHub() (eventhub.Client, []string) {
log.Fatal(err)
}
aadToken, err := getEventHubsTokenProvider()
provider, err := aad.NewProviderFromEnvironment()
if err != nil {
log.Fatal(err)
}
provider := aad.NewProvider(aadToken)
hub, err := eventhub.NewClient(namespace, HubName, provider)
if err != nil {
panic(err)
@ -77,30 +78,6 @@ func mustGetenv(key string) string {
return v
}
func getEventHubsTokenProvider() (*adal.ServicePrincipalToken, error) {
// TODO: fix the azure environment var for the SB endpoint and EH endpoint
return getTokenProvider("https://eventhubs.azure.net/")
}
func getTokenProvider(resourceURI string) (*adal.ServicePrincipalToken, error) {
oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, mustGetenv("AZURE_TENANT_ID"))
if err != nil {
log.Fatalln(err)
}
tokenProvider, err := adal.NewServicePrincipalToken(*oauthConfig, mustGetenv("AZURE_CLIENT_ID"), mustGetenv("AZURE_CLIENT_SECRET"), resourceURI)
if err != nil {
return nil, err
}
err = tokenProvider.Refresh()
if err != nil {
return nil, err
}
return tokenProvider, nil
}
func ensureEventHub(ctx context.Context, name string) (*mgmt.Model, error) {
namespace := mustGetenv("EVENTHUB_NAMESPACE")
client := getEventHubMgmtClient()
@ -126,10 +103,10 @@ func ensureEventHub(ctx context.Context, name string) (*mgmt.Model, error) {
func getEventHubMgmtClient() *mgmt.EventHubsClient {
subID := mustGetenv("AZURE_SUBSCRIPTION_ID")
client := mgmt.NewEventHubsClientWithBaseURI(azure.PublicCloud.ResourceManagerEndpoint, subID)
armToken, err := getTokenProvider(azure.PublicCloud.ResourceManagerEndpoint)
a, err := azauth.NewAuthorizerFromEnvironment()
if err != nil {
log.Fatal(err)
}
client.Authorizer = autorest.NewBearerAuthorizer(armToken)
client.Authorizer = a
return &client
}

Просмотреть файл

@ -1,18 +1,19 @@
package main
import (
"github.com/Azure/azure-event-hubs-go"
"fmt"
"os"
"github.com/Azure/go-autorest/autorest/azure"
"log"
"context"
"pack.ag/amqp"
"bufio"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest"
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
"context"
"fmt"
"log"
"os"
"time"
"github.com/Azure/azure-event-hubs-go"
"github.com/Azure/azure-event-hubs-go/aad"
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
"github.com/Azure/go-autorest/autorest/azure"
azauth "github.com/Azure/go-autorest/autorest/azure/auth"
"pack.ag/amqp"
)
const (
@ -28,10 +29,12 @@ func main() {
for {
fmt.Print("Enter text: ")
text, _ := reader.ReadString('\n')
hub.Send(context.Background(), &amqp.Message{Data: []byte(text)})
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
hub.Send(ctx, amqp.NewMessage([]byte(text)))
if text == "exit\n" {
break
}
cancel()
}
}
@ -42,11 +45,10 @@ func initHub() (eventhub.Client, []string) {
log.Fatal(err)
}
aadToken, err := getEventHubsTokenProvider()
provider, err := aad.NewProviderFromEnvironment()
if err != nil {
log.Fatal(err)
}
provider := aad.NewProvider(aadToken)
hub, err := eventhub.NewClient(namespace, HubName, provider)
if err != nil {
panic(err)
@ -62,30 +64,6 @@ func mustGetenv(key string) string {
return v
}
func getEventHubsTokenProvider() (*adal.ServicePrincipalToken, error) {
// TODO: fix the azure environment var for the SB endpoint and EH endpoint
return getTokenProvider("https://eventhubs.azure.net/")
}
func getTokenProvider(resourceURI string) (*adal.ServicePrincipalToken, error) {
oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, mustGetenv("AZURE_TENANT_ID"))
if err != nil {
log.Fatalln(err)
}
tokenProvider, err := adal.NewServicePrincipalToken(*oauthConfig, mustGetenv("AZURE_CLIENT_ID"), mustGetenv("AZURE_CLIENT_SECRET"), resourceURI)
if err != nil {
return nil, err
}
err = tokenProvider.Refresh()
if err != nil {
return nil, err
}
return tokenProvider, nil
}
func ensureEventHub(ctx context.Context, name string) (*mgmt.Model, error) {
namespace := mustGetenv("EVENTHUB_NAMESPACE")
client := getEventHubMgmtClient()
@ -111,11 +89,11 @@ func ensureEventHub(ctx context.Context, name string) (*mgmt.Model, error) {
func getEventHubMgmtClient() *mgmt.EventHubsClient {
subID := mustGetenv("AZURE_SUBSCRIPTION_ID")
client := mgmt.NewEventHubsClientWithBaseURI(azure.PublicCloud.ResourceManagerEndpoint, subID)
armToken, err := getTokenProvider(azure.PublicCloud.ResourceManagerEndpoint)
a, err := azauth.NewAuthorizerFromEnvironment()
if err != nil {
log.Fatal(err)
}
client.Authorizer = autorest.NewBearerAuthorizer(armToken)
client.Authorizer = a
return &client
}

Просмотреть файл

@ -1,11 +1,23 @@
package aad
import (
"github.com/Azure/azure-event-hubs-go/auth"
"github.com/Azure/go-autorest/autorest/adal"
log "github.com/sirupsen/logrus"
"crypto/rsa"
"crypto/x509"
"fmt"
"io/ioutil"
"os"
"strconv"
"time"
"github.com/Azure/azure-event-hubs-go/auth"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/pkcs12"
)
const (
resource = "https://eventhubs.azure.net/"
)
type (
@ -22,11 +34,95 @@ func NewProvider(tokenProvider *adal.ServicePrincipalToken) auth.TokenProvider {
}
}
// NewProviderFromEnvironment builds a new TokenProvider using environment variable available
//
// 1. Client Credentials: attempt to authenticate with a Service Principal via "AZURE_TENANT_ID", "AZURE_CLIENT_ID" and
// "AZURE_CLIENT_SECRET"
//
// 2. Client Certificate: attempt to authenticate with a Service Principal via "AZURE_TENANT_ID", "AZURE_CLIENT_ID",
// "AZURE_CERTIFICATE_PATH" and "AZURE_CERTIFICATE_PASSWORD"
//
// 3. Managed Service Identity (MSI): attempt to authenticate via MSI
func NewProviderFromEnvironment() (auth.TokenProvider, error) {
tenantID := os.Getenv("AZURE_TENANT_ID")
clientID := os.Getenv("AZURE_CLIENT_ID")
clientSecret := os.Getenv("AZURE_CLIENT_SECRET")
certificatePath := os.Getenv("AZURE_CERTIFICATE_PATH")
certificatePassword := os.Getenv("AZURE_CERTIFICATE_PASSWORD")
envName := os.Getenv("AZURE_ENVIRONMENT")
var env azure.Environment
if envName == "" {
env = azure.PublicCloud
} else {
var err error
env, err = azure.EnvironmentFromName(envName)
if err != nil {
return nil, err
}
}
oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, tenantID)
if err != nil {
return nil, err
}
// 1.Client Credentials
if clientSecret != "" {
log.Debug("creating a token via a service principal client secret")
spToken, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from client credentials: %v", err)
}
if err := spToken.Refresh(); err != nil {
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
}
return NewProvider(spToken), nil
}
// 2. Client Certificate
if certificatePath != "" {
log.Debug("creating a token via a service principal client certificate")
certData, err := ioutil.ReadFile(certificatePath)
if err != nil {
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", certificatePath, err)
}
certificate, rsaPrivateKey, err := decodePkcs12(certData, certificatePassword)
if err != nil {
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
}
spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, certificate, rsaPrivateKey, resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from certificate auth: %v", err)
}
if err := spToken.Refresh(); err != nil {
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
}
return NewProvider(spToken), nil
}
// 3. By default return MSI
log.Debug("creating a token via MSI")
msiEndpoint, err := adal.GetMSIVMEndpoint()
if err != nil {
return nil, err
}
spToken, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from MSI: %v", err)
}
if err := spToken.Refresh(); err != nil {
return nil, fmt.Errorf("failed to refersh token: %v", spToken)
}
return NewProvider(spToken), nil
}
// GetToken gets a CBS JWT token
func (t *TokenProvider) GetToken(audience string) (*auth.Token, error) {
token := t.tokenProvider.Token()
expireTicks, err := strconv.Atoi(token.ExpiresOn)
if err != nil {
log.Debugf("%v", token.AccessToken)
return nil, err
}
currentTicks := time.Now().UTC().Unix()
@ -42,3 +138,17 @@ func (t *TokenProvider) GetToken(audience string) (*auth.Token, error) {
return auth.NewToken(auth.CbsTokenTypeJwt, token.AccessToken, token.ExpiresOn), nil
}
func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
privateKey, certificate, err := pkcs12.Decode(pkcs, password)
if err != nil {
return nil, nil, err
}
rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
if !isRsaKey {
return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
}
return certificate, rsaPrivateKey, nil
}

Просмотреть файл

@ -2,11 +2,12 @@ package cbs
import (
"context"
"time"
"github.com/Azure/azure-event-hubs-go/auth"
"github.com/Azure/azure-event-hubs-go/rpc"
log "github.com/sirupsen/logrus"
"pack.ag/amqp"
"time"
)
const (
@ -19,7 +20,7 @@ const (
)
// NegotiateClaim attempts to put a token to the $cbs management endpoint to negotiate auth for the given audience
func NegotiateClaim(audience string, conn *amqp.Client, provider auth.TokenProvider) error {
func NegotiateClaim(ctx context.Context, audience string, conn *amqp.Client, provider auth.TokenProvider) error {
link, err := rpc.NewLink(conn, cbsAddress)
if err != nil {
return err
@ -42,7 +43,7 @@ func NegotiateClaim(audience string, conn *amqp.Client, provider auth.TokenProvi
},
}
res, err := link.RetryableRPC(context.Background(), 3, 1*time.Second, msg)
res, err := link.RetryableRPC(ctx, 3, 1*time.Second, msg)
if err == nil {
log.Debugf("negotiated with response code %d and message: %s", res.Code, res.Description)
} else {

41
common/conn.go Normal file
Просмотреть файл

@ -0,0 +1,41 @@
package common
import (
"errors"
"regexp"
)
var (
connStrRegex = regexp.MustCompile(`Endpoint=sb:\/\/(?P<Host>.+?);SharedAccessKeyName=(?P<KeyName>.+?);SharedAccessKey=(?P<Key>.+)`)
hostStrRegex = regexp.MustCompile(`^(?P<Namespace>.+?)\..+`)
)
type (
// ParsedConn is the structure of a parsed Service Bus or Event Hub connection string.
ParsedConn struct {
Host string
Namespace string
KeyName string
Key string
}
)
// newParsedConnection is a constructor for a parsedConn and verifies each of the inputs is non-null.
func newParsedConnection(host, namespace, keyName, key string) (*ParsedConn, error) {
if host == "" || keyName == "" || key == "" {
return nil, errors.New("connection string contains an empty entry")
}
return &ParsedConn{
Host: "amqps://" + host,
Namespace: namespace,
KeyName: keyName,
Key: key,
}, nil
}
// ParsedConnectionFromStr takes a string connection string from the Azure portal and returns the parsed representation.
func ParsedConnectionFromStr(connStr string) (*ParsedConn, error) {
matches := connStrRegex.FindStringSubmatch(connStr)
namespaceMatches := hostStrRegex.FindStringSubmatch(matches[1])
return newParsedConnection(matches[1], namespaceMatches[1], matches[2], matches[3])
}

23
common/conn_test.go Normal file
Просмотреть файл

@ -0,0 +1,23 @@
package common
import (
"testing"
"github.com/stretchr/testify/assert"
)
const (
namespace = "mynamespace"
keyName = "keyName"
secret = "superSecret"
connStr = "Endpoint=sb://" + namespace + ".servicebus.windows.net/;SharedAccessKeyName=" + keyName + ";SharedAccessKey=" + secret
)
func TestParsedConnectionFromStr(t *testing.T) {
parsed, err := ParsedConnectionFromStr(connStr)
assert.Nil(t, err, err)
assert.Equal(t, "amqps://"+namespace+".servicebus.windows.net/", parsed.Host)
assert.Equal(t, namespace, parsed.Namespace)
assert.Equal(t, keyName, parsed.KeyName)
assert.Equal(t, secret, parsed.Key)
}

87
hub.go
Просмотреть файл

@ -3,13 +3,19 @@ package eventhub
import (
"context"
"fmt"
"github.com/Azure/azure-event-hubs-go/auth"
"github.com/Azure/azure-event-hubs-go/mgmt"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pkg/errors"
"pack.ag/amqp"
"os"
"path"
"sync"
"github.com/Azure/azure-event-hubs-go/aad"
"github.com/Azure/azure-event-hubs-go/auth"
"github.com/Azure/azure-event-hubs-go/mgmt"
"github.com/Azure/azure-event-hubs-go/persist"
"github.com/Azure/azure-event-hubs-go/sas"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"pack.ag/amqp"
)
const (
@ -26,7 +32,7 @@ type (
senderPartitionID *string
receiverMu sync.Mutex
senderMu sync.Mutex
offsetPersister OffsetPersister
offsetPersister persist.OffsetPersister
userAgent string
}
@ -40,7 +46,7 @@ type (
// Receiver provides the ability to receive messages
Receiver interface {
Receive(partitionID string, handler Handler, opts ...ReceiveOption) error
Receive(ctx context.Context, partitionID string, handler Handler, opts ...ReceiveOption) error
}
// Closer provides the ability to close a connection or client
@ -64,13 +70,6 @@ type (
// HubOption provides structure for configuring new Event Hub instances
HubOption func(h *hub) error
// OffsetPersister provides persistence for the received offset for a given namespace, hub name, consumer group, partition Id and
// offset so that if a receiver where to be interrupted, it could resume after the last consumed event.
OffsetPersister interface {
Write(namespace, name, consumerGroup, partitionID, offset string) error
Read(namespace, name, consumerGroup, partitionID string) (string, error)
}
)
// NewClient creates a new Event Hub client for sending and receiving messages
@ -79,7 +78,7 @@ func NewClient(namespace, name string, tokenProvider auth.TokenProvider, opts ..
h := &hub{
name: name,
namespace: ns,
offsetPersister: new(MemoryPersister),
offsetPersister: new(persist.MemoryPersister),
userAgent: rootUserAgent,
}
@ -93,6 +92,45 @@ func NewClient(namespace, name string, tokenProvider auth.TokenProvider, opts ..
return h, nil
}
// NewClientFromEnvironment creates a new Event Hub client for sending and receiving messages from environment variables
func NewClientFromEnvironment(opts ...HubOption) (Client, error) {
const envErrMsg = "environment var %s must not be empty"
var namespace, name string
var provider auth.TokenProvider
if namespace = os.Getenv("EVENTHUB_NAMESPACE"); namespace == "" {
return nil, errors.Errorf(envErrMsg, "EVENTHUB_NAMESPACE")
}
if name = os.Getenv("EVENTHUB_NAME"); name == "" {
return nil, errors.Errorf(envErrMsg, "EVENTHUB_NAME")
}
aadProvider, aadErr := aad.NewProviderFromEnvironment()
sasProvider, sasErr := sas.NewProviderFromEnvironment()
if aadErr != nil && sasErr != nil {
// both failed
log.Debug("both token providers failed")
return nil, errors.Errorf("neither Azure Active Directory nor SAS token provider could be built - AAD error: %v, SAS error: %v", aadErr, sasErr)
}
if aadProvider != nil {
log.Debug("using AAD provider")
provider = aadProvider
} else {
log.Debug("using SAS provider")
provider = sasProvider
}
h, err := NewClient(namespace, name, provider, opts...)
if err != nil {
return nil, err
}
return h, nil
}
// GetRuntimeInformation fetches runtime information from the Event Hub management node
func (h *hub) GetRuntimeInformation(ctx context.Context) (*mgmt.HubRuntimeInformation, error) {
client := mgmt.NewClient(h.namespace.name, h.name, h.namespace.tokenProvider, h.namespace.environment)
@ -123,18 +161,21 @@ func (h *hub) GetPartitionInformation(ctx context.Context, partitionID string) (
// Close drains and closes all of the existing senders, receivers and connections
func (h *hub) Close() error {
var lastErr error
for _, r := range h.receivers {
r.Close()
if err := r.Close(); err != nil {
lastErr = err
}
}
return nil
return lastErr
}
// Listen subscribes for messages sent to the provided entityPath.
func (h *hub) Receive(partitionID string, handler Handler, opts ...ReceiveOption) error {
func (h *hub) Receive(ctx context.Context, partitionID string, handler Handler, opts ...ReceiveOption) error {
h.receiverMu.Lock()
defer h.receiverMu.Unlock()
receiver, err := h.newReceiver(partitionID, opts...)
receiver, err := h.newReceiver(ctx, partitionID, opts...)
if err != nil {
return err
}
@ -146,7 +187,7 @@ func (h *hub) Receive(partitionID string, handler Handler, opts ...ReceiveOption
// Send sends an AMQP message to the broker
func (h *hub) Send(ctx context.Context, message *amqp.Message, opts ...SendOption) error {
sender, err := h.getSender()
sender, err := h.getSender(ctx)
if err != nil {
return err
}
@ -168,7 +209,7 @@ func HubWithPartitionedSender(partitionID string) HubOption {
// HubWithOffsetPersistence configures the hub instance to read and write offsets so that if a hub is interrupted, it
// can resume after the last consumed event.
func HubWithOffsetPersistence(offsetPersister OffsetPersister) HubOption {
func HubWithOffsetPersistence(offsetPersister persist.OffsetPersister) HubOption {
return func(h *hub) error {
h.offsetPersister = offsetPersister
return nil
@ -195,12 +236,12 @@ func (h *hub) appendAgent(userAgent string) error {
return nil
}
func (h *hub) getSender() (*sender, error) {
func (h *hub) getSender(ctx context.Context) (*sender, error) {
h.senderMu.Lock()
defer h.senderMu.Unlock()
if h.sender == nil {
s, err := h.newSender()
s, err := h.newSender(ctx)
if err != nil {
return nil, err
}

Просмотреть файл

@ -3,28 +3,59 @@ package eventhub
import (
"context"
"fmt"
"github.com/Azure/azure-event-hubs-go/aad"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"math/rand"
"pack.ag/amqp"
"os"
"sync"
"testing"
"time"
"github.com/Azure/azure-event-hubs-go/aad"
"github.com/Azure/azure-event-hubs-go/sas"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"pack.ag/amqp"
)
func (suite *eventHubSuite) TestSasToken() {
tests := map[string]func(*testing.T, Client, []string, string){
"TestMultiSendAndReceive": testMultiSendAndReceive,
"TestHubRuntimeInformation": testHubRuntimeInformation,
"TestHubPartitionRuntimeInformation": testHubPartitionRuntimeInformation,
}
for name, testFunc := range tests {
setupTestTeardown := func(t *testing.T) {
hubName := randomName("goehtest", 10)
mgmtHub, err := suite.ensureEventHub(context.Background(), hubName)
if err != nil {
t.Fatal(err)
}
defer suite.deleteEventHub(context.Background(), hubName)
provider, err := sas.NewProviderFromEnvironment()
if err != nil {
t.Fatal(err)
}
client, err := NewClient(suite.namespace, hubName, provider)
if err != nil {
t.Fatal(err)
}
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
if err := client.Close(); err != nil {
t.Fatal(err)
}
}
suite.T().Run(name, setupTestTeardown)
}
}
func (suite *eventHubSuite) TestPartitionedSender() {
tests := map[string]func(*testing.T, Client, string){
"TestSend": testBasicSend,
"TestSendAndReceive": testBasicSendAndReceive,
}
token, err := suite.getEventHubsTokenProvider()
if err != nil {
log.Fatal(err)
}
token.Refresh()
for name, testFunc := range tests {
setupTestTeardown := func(t *testing.T) {
hubName := randomName("goehtest", 10)
@ -34,7 +65,10 @@ func (suite *eventHubSuite) TestPartitionedSender() {
}
defer suite.deleteEventHub(context.Background(), hubName)
partitionID := (*mgmtHub.PartitionIds)[0]
provider := aad.NewProvider(token)
provider, err := aad.NewProviderFromEnvironment()
if err != nil {
t.Fatal(err)
}
client, err := NewClient(suite.namespace, hubName, provider, HubWithPartitionedSender(partitionID))
if err != nil {
t.Fatal(err)
@ -51,7 +85,7 @@ func (suite *eventHubSuite) TestPartitionedSender() {
}
func testBasicSend(t *testing.T, client Client, _ string) {
err := client.Send(context.Background(), amqp.NewMessage([]byte("Hello!")) )
err := client.Send(context.Background(), amqp.NewMessage([]byte("Hello!")))
assert.Nil(t, err)
}
@ -75,12 +109,14 @@ func testBasicSendAndReceive(t *testing.T, client Client, partitionID string) {
}
count := 0
err := client.Receive(partitionID, func(ctx context.Context, msg *amqp.Message) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
err := client.Receive(ctx, partitionID, func(ctx context.Context, msg *amqp.Message) error {
assert.Equal(t, messages[count], string(msg.Data[0]))
count++
wg.Done()
return nil
}, ReceiveWithPrefetchCount(100))
cancel()
if err != nil {
t.Fatal(err)
}
@ -88,16 +124,10 @@ func testBasicSendAndReceive(t *testing.T, client Client, partitionID string) {
}
func (suite *eventHubSuite) TestMultiPartition() {
tests := map[string]func(*testing.T, Client, []string){
tests := map[string]func(*testing.T, Client, []string, string){
"TestMultiSendAndReceive": testMultiSendAndReceive,
}
token, err := suite.getEventHubsTokenProvider()
if err != nil {
log.Fatal(err)
}
token.Refresh()
for name, testFunc := range tests {
setupTestTeardown := func(t *testing.T) {
hubName := randomName("goehtest", 10)
@ -106,13 +136,16 @@ func (suite *eventHubSuite) TestMultiPartition() {
t.Fatal(err)
}
defer suite.deleteEventHub(context.Background(), hubName)
provider := aad.NewProvider(token)
provider, err := aad.NewProviderFromEnvironment()
if err != nil {
t.Fatal(err)
}
client, err := NewClient(suite.namespace, hubName, provider)
if err != nil {
t.Fatal(err)
}
testFunc(t, client, *mgmtHub.PartitionIds)
testFunc(t, client, *mgmtHub.PartitionIds, hubName)
if err := client.Close(); err != nil {
t.Fatal(err)
}
@ -122,7 +155,7 @@ func (suite *eventHubSuite) TestMultiPartition() {
}
}
func testMultiSendAndReceive(t *testing.T, client Client, partitionIDs []string) {
func testMultiSendAndReceive(t *testing.T, client Client, partitionIDs []string, _ string) {
numMessages := rand.Intn(100) + 20
var wg sync.WaitGroup
wg.Add(numMessages)
@ -141,8 +174,9 @@ func testMultiSendAndReceive(t *testing.T, client Client, partitionIDs []string)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
for _, partitionID := range partitionIDs {
err := client.Receive(partitionID, func(ctx context.Context, msg *amqp.Message) error {
err := client.Receive(ctx, partitionID, func(ctx context.Context, msg *amqp.Message) error {
wg.Done()
return nil
}, ReceiveWithPrefetchCount(100))
@ -150,6 +184,7 @@ func testMultiSendAndReceive(t *testing.T, client Client, partitionIDs []string)
t.Fatal(err)
}
}
cancel()
wg.Wait()
}
@ -159,12 +194,6 @@ func (suite *eventHubSuite) TestHubManagement() {
"TestHubPartitionRuntimeInformation": testHubPartitionRuntimeInformation,
}
token, err := suite.getEventHubsTokenProvider()
if err != nil {
log.Fatal(err)
}
token.Refresh()
for name, testFunc := range tests {
setupTestTeardown := func(t *testing.T) {
hubName := randomName("goehtest", 10)
@ -173,7 +202,10 @@ func (suite *eventHubSuite) TestHubManagement() {
t.Fatal(err)
}
defer suite.deleteEventHub(context.Background(), hubName)
provider := aad.NewProvider(token)
provider, err := aad.NewProviderFromEnvironment()
if err != nil {
t.Fatal(err)
}
client, err := NewClient(suite.namespace, hubName, provider)
if err != nil {
t.Fatal(err)
@ -208,6 +240,13 @@ func testHubPartitionRuntimeInformation(t *testing.T, client Client, partitionID
assert.Equal(t, "-1", info.LastEnqueuedOffset) // brand new, so should be very last
}
func TestEnvironmentalCreation(t *testing.T) {
os.Setenv("EVENTHUB_NAME", "foo")
_, err := NewClientFromEnvironment()
assert.Nil(t, err)
os.Unsetenv("EVENTHUB_NAME")
}
func BenchmarkReceive(b *testing.B) {
suite := new(eventHubSuite)
suite.SetupSuite()
@ -225,8 +264,10 @@ func BenchmarkReceive(b *testing.B) {
messages[i] = randomName("hello", 10)
}
token, err := suite.getEventHubsTokenProvider()
provider := aad.NewProvider(token)
provider, err := aad.NewProviderFromEnvironment()
if err != nil {
log.Fatal(err)
}
hub, err := NewClient(suite.namespace, *mgmtHub.Name, provider)
if err != nil {
b.Fatal(err)
@ -248,9 +289,10 @@ func BenchmarkReceive(b *testing.B) {
b.ResetTimer()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
// receive from all partition IDs
for _, partitionID := range *mgmtHub.PartitionIds {
err = hub.Receive(partitionID, func(ctx context.Context, msg *amqp.Message) error {
err = hub.Receive(ctx, partitionID, func(ctx context.Context, msg *amqp.Message) error {
wg.Done()
return nil
}, ReceiveWithPrefetchCount(100))
@ -258,7 +300,7 @@ func BenchmarkReceive(b *testing.B) {
b.Fatal(err)
}
}
cancel()
wg.Wait()
b.StopTimer()
}

Просмотреть файл

@ -3,12 +3,13 @@ package mgmt
import (
"context"
"fmt"
"time"
"github.com/Azure/azure-event-hubs-go/auth"
"github.com/Azure/azure-event-hubs-go/rpc"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pkg/errors"
"pack.ag/amqp"
"time"
)
const (
@ -103,7 +104,7 @@ func (c *Client) GetHubRuntimeInformation(ctx context.Context, conn *amqp.Client
entityNameKey: c.hubName,
},
}
err = c.addSecurityToken(msg)
msg, err = c.addSecurityToken(msg)
if err != nil {
return nil, err
}
@ -135,7 +136,7 @@ func (c *Client) GetHubPartitionRuntimeInformation(ctx context.Context, conn *am
partitionNameKey: partitionID,
},
}
err = c.addSecurityToken(msg)
msg, err = c.addSecurityToken(msg)
if err != nil {
return nil, err
}
@ -152,13 +153,16 @@ func (c *Client) GetHubPartitionRuntimeInformation(ctx context.Context, conn *am
return hubPartitionRuntimeInfo, nil
}
func (c *Client) addSecurityToken(msg *amqp.Message) error {
func (c *Client) addSecurityToken(msg *amqp.Message) (*amqp.Message, error) {
// TODO (devigned): need to uncomment this functionality after getting some guidance from the Event Hubs team (only works for SAS tokens right now)
//token, err := c.tokenProvider.GetToken(c.getTokenAudience())
//if err != nil {
// return nil
// return nil, err
//}
//msg.ApplicationProperties[securityTokenKey] = token.Token
return nil
return msg, nil
}
func (c *Client) getTokenAudience() string {
@ -228,10 +232,6 @@ func newHubRuntimeInformation(msg *amqp.Message) (*HubRuntimeInformation, error)
}
if createdAt, ok := values[resultCreatedAtKey].(time.Time); ok {
//t, err := time.Parse("UnixDate", createdAt)
//if err != nil {
// return nil, err
//}
info.CreatedAt = createdAt
} else {
return nil, errors.Errorf(errMsgFmt, resultCreatedAtKey, values)

Просмотреть файл

@ -1,14 +1,15 @@
package eventhub
import (
"fmt"
"context"
"runtime"
"sync"
"github.com/Azure/azure-event-hubs-go/auth"
"github.com/Azure/azure-event-hubs-go/cbs"
"github.com/Azure/go-autorest/autorest/azure"
log "github.com/sirupsen/logrus"
"pack.ag/amqp"
"runtime"
"sync"
)
type (
@ -48,7 +49,8 @@ func (ns *namespace) connection() (*amqp.Client, error) {
amqp.ConnProperty("version", "0.0.1"),
amqp.ConnProperty("platform", runtime.GOOS),
amqp.ConnProperty("framework", runtime.Version()),
amqp.ConnProperty("user-agent", rootUserAgent))
amqp.ConnProperty("user-agent", rootUserAgent),
)
if err != nil {
return nil, err
}
@ -57,17 +59,17 @@ func (ns *namespace) connection() (*amqp.Client, error) {
return ns.client, nil
}
func (ns *namespace) negotiateClaim(entityPath string) error {
func (ns *namespace) negotiateClaim(ctx context.Context, entityPath string) error {
audience := ns.getEntityAudience(entityPath)
conn, err := ns.connection()
if err != nil {
return err
}
return cbs.NegotiateClaim(audience, conn, ns.tokenProvider)
return cbs.NegotiateClaim(ctx, audience, conn, ns.tokenProvider)
}
func (ns *namespace) getAmqpHostURI() string {
return fmt.Sprintf("amqps://%s.%s/", ns.name, ns.environment.ServiceBusEndpointSuffix)
return "amqps://" + ns.name + "." + ns.environment.ServiceBusEndpointSuffix + "/"
}
func (ns *namespace) getEntityAudience(entityPath string) string {

Просмотреть файл

@ -1,4 +1,4 @@
package eventhub
package persist
import (
"github.com/pkg/errors"
@ -7,6 +7,13 @@ import (
)
type (
// OffsetPersister provides persistence for the received offset for a given namespace, hub name, consumer group, partition Id and
// offset so that if a receiver where to be interrupted, it could resume after the last consumed event.
OffsetPersister interface {
Write(namespace, name, consumerGroup, partitionID, offset string) error
Read(namespace, name, consumerGroup, partitionID string) (string, error)
}
// MemoryPersister is a default implementation of a Hub OffsetPersister, which will persist offset information in
// memory.
MemoryPersister struct {

Просмотреть файл

@ -3,11 +3,11 @@ package eventhub
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"net"
"pack.ag/amqp"
"sync/atomic"
"time"
"github.com/Azure/azure-event-hubs-go/persist"
log "github.com/sirupsen/logrus"
"pack.ag/amqp"
)
const (
@ -36,7 +36,7 @@ type (
consumerGroup string
partitionID string
prefetchCount uint32
done chan struct{}
done func()
lastReceivedOffset atomic.Value
}
@ -53,8 +53,6 @@ func ReceiveWithConsumerGroup(consumerGroup string) ReceiveOption {
}
// ReceiveWithStartingOffset configures the receiver to start at a given position in the event stream
//
// This setting will be overridden by the Hub's OffsetPersister if an offset can be read.
func ReceiveWithStartingOffset(offset string) ReceiveOption {
return func(receiver *receiver) error {
receiver.storeLastReceivedOffset(offset)
@ -62,6 +60,14 @@ func ReceiveWithStartingOffset(offset string) ReceiveOption {
}
}
// ReceiveWithLatestOffset configures the receiver to start at a given position in the event stream
func ReceiveWithLatestOffset() ReceiveOption {
return func(receiver *receiver) error {
receiver.storeLastReceivedOffset(EndOfStream)
return nil
}
}
// ReceiveWithPrefetchCount configures the receiver to attempt to fetch as many messages as the prefetch amount
func ReceiveWithPrefetchCount(prefetch uint32) ReceiveOption {
return func(receiver *receiver) error {
@ -71,13 +77,12 @@ func ReceiveWithPrefetchCount(prefetch uint32) ReceiveOption {
}
// newReceiver creates a new Service Bus message listener given an AMQP client and an entity path
func (h *hub) newReceiver(partitionID string, opts ...ReceiveOption) (*receiver, error) {
func (h *hub) newReceiver(ctx context.Context, partitionID string, opts ...ReceiveOption) (*receiver, error) {
receiver := &receiver{
hub: h,
consumerGroup: DefaultConsumerGroup,
prefetchCount: defaultPrefetchCount,
partitionID: partitionID,
done: make(chan struct{}),
}
for _, opt := range opts {
@ -87,38 +92,33 @@ func (h *hub) newReceiver(partitionID string, opts ...ReceiveOption) (*receiver,
}
log.Debugf("creating a new receiver for entity path %s", receiver.getAddress())
err := receiver.newSessionAndLink()
if err != nil {
return nil, err
}
return receiver, nil
err := receiver.newSessionAndLink(ctx)
return receiver, err
}
// Close will close the AMQP session and link of the receiver
func (r *receiver) Close() error {
close(r.done)
if r.done != nil {
r.done()
}
err := r.receiver.Close()
if err != nil {
_ = r.session.Close()
return err
}
err = r.session.Close()
if err != nil {
return err
}
return nil
return r.session.Close()
}
// Recover will attempt to close the current session and link, then rebuild them
func (r *receiver) Recover() error {
func (r *receiver) Recover(ctx context.Context) error {
err := r.Close()
if err != nil {
return err
}
err = r.newSessionAndLink()
err = r.newSessionAndLink(ctx)
if err != nil {
return err
}
@ -128,20 +128,20 @@ func (r *receiver) Recover() error {
// Listen start a listener for messages sent to the entity path
func (r *receiver) Listen(handler Handler) {
ctx, done := context.WithCancel(context.Background())
r.done = done
messages := make(chan *amqp.Message)
go r.listenForMessages(messages)
go r.handleMessages(messages, handler)
go r.listenForMessages(ctx, messages)
go r.handleMessages(ctx, messages, handler)
}
func (r *receiver) handleMessages(messages chan *amqp.Message, handler Handler) {
func (r *receiver) handleMessages(ctx context.Context, messages chan *amqp.Message, handler Handler) {
for {
select {
case <-r.done:
case <-ctx.Done():
log.Debug("done handling messages")
close(messages)
return
case msg := <-messages:
ctx := context.Background()
id := messageID(msg)
log.Debugf("message id: %v is being passed to handler", id)
@ -149,48 +149,39 @@ func (r *receiver) handleMessages(messages chan *amqp.Message, handler Handler)
if err != nil {
msg.Reject()
log.Debugf("message rejected: id: %v", id)
} else {
// Accept message
msg.Accept()
log.Debugf("message accepted: id: %v", id)
continue
}
msg.Accept()
log.Debugf("message accepted: id: %v", id)
}
}
}
func (r *receiver) listenForMessages(msgChan chan *amqp.Message) {
func (r *receiver) listenForMessages(ctx context.Context, msgChan chan *amqp.Message) {
for {
select {
case <-r.done:
log.Debug("done listening for messages")
return
default:
waitCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
msg, err := r.receiver.Receive(waitCtx)
cancel()
if err == amqp.ErrLinkClosed || err == amqp.ErrSessionClosed {
log.Debug("done listening for messages due to link or session closed")
time.Sleep(10 * time.Second)
return
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
log.Debug("attempting to receive messages timed out")
continue
} else if err != nil {
log.Error(err)
msg, err := r.receiver.Receive(ctx)
if err != nil {
if ctx.Err() != nil {
return
}
log.Error(err)
return
}
r.receivedMessage(msg)
msgChan <- msg
id := messageID(msg)
log.Debugf("Message received: %s", id)
select {
case msgChan <- msg:
case <-ctx.Done():
return
}
}
}
// newSessionAndLink will replace the session and link on the receiver
func (r *receiver) newSessionAndLink() error {
func (r *receiver) newSessionAndLink(ctx context.Context) error {
address := r.getAddress()
err := r.hub.namespace.negotiateClaim(address)
err := r.hub.namespace.negotiateClaim(ctx, address)
if err != nil {
return err
}
@ -258,7 +249,7 @@ func (r *receiver) hubName() string {
return r.hub.name
}
func (r *receiver) offsetPersister() OffsetPersister {
func (r *receiver) offsetPersister() persist.OffsetPersister {
return r.hub.offsetPersister
}
@ -280,7 +271,7 @@ func (r *receiver) receivedMessage(msg *amqp.Message) {
}
func messageID(msg *amqp.Message) interface{} {
id := interface{}("null")
var id interface{} = "null"
if msg.Properties != nil {
id = msg.Properties.MessageID
}

Просмотреть файл

@ -3,14 +3,15 @@ package rpc
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/Azure/azure-event-hubs-go/common"
"github.com/pkg/errors"
"github.com/satori/go.uuid"
log "github.com/sirupsen/logrus"
"pack.ag/amqp"
"strings"
"sync"
"time"
)
const (
@ -56,7 +57,8 @@ func NewLink(conn *amqp.Client, address string) (*Link, error) {
clientAddress := strings.Replace("$", "", address, -1) + replyPostfix + id
authReceiver, err := authSession.NewReceiver(
amqp.LinkSourceAddress(address),
amqp.LinkTargetAddress(clientAddress))
amqp.LinkTargetAddress(clientAddress),
)
if err != nil {
return nil, err
}
@ -79,26 +81,19 @@ func (l *Link) RetryableRPC(ctx context.Context, times int, delay time.Duration,
return nil, err
}
if res.ServerError() {
switch {
case res.Code >= 200 && res.Code < 300:
log.Debugf("successful rpc on link %s: status code %d and description: %s", l.id, res.Code, res.Description)
return res, nil
case res.Code >= 500:
errMessage := fmt.Sprintf("server error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
log.Debugln(errMessage)
return nil, &common.Retryable{Message: errMessage}
}
if res.ClientError() {
errMessage := fmt.Sprintf("client error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
default:
errMessage := fmt.Sprintf("unhandled error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
log.Debugln(errMessage)
return nil, errors.New(errMessage)
return nil, &common.Retryable{Message: errMessage}
}
if res.Success() {
log.Debugf("successful rpc on link %s: status code %d and description: %s", l.id, res.Code, res.Description)
return res, nil
}
errMessage := fmt.Sprintf("unhandled error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
log.Debugln(errMessage)
return nil, &common.Retryable{Message: errMessage}
})
if err != nil {
return nil, err
@ -145,32 +140,39 @@ func (l *Link) RPC(ctx context.Context, msg *amqp.Message) (*Response, error) {
return response, err
}
// Close the link sender, receiver and session
func (l *Link) Close() {
if l.sender != nil {
l.sender.Close()
// Close the link receiver, sender and session
func (l *Link) Close() error {
if err := l.closeReceiver(); err != nil {
_ = l.closeSender()
_ = l.closeSession()
return err
}
if err := l.closeSender(); err != nil {
_ = l.closeSession()
return err
}
return l.closeSession()
}
func (l *Link) closeReceiver() error {
if l.receiver != nil {
l.receiver.Close()
return l.receiver.Close()
}
return nil
}
func (l *Link) closeSender() error {
if l.sender != nil {
return l.sender.Close()
}
return nil
}
func (l *Link) closeSession() error {
if l.session != nil {
l.session.Close()
return l.session.Close()
}
}
// Success return true if the status code is between 200 and 300
func (r *Response) Success() bool {
return r.Code >= 200 && r.Code < 300
}
// ServerError is true when status code is 500 or greater
func (r *Response) ServerError() bool {
return r.Code >= 500
}
// ClientError is true when status code is in the 400s
func (r *Response) ClientError() bool {
return r.Code >= 400 && r.Code < 500
return nil
}

Просмотреть файл

@ -5,11 +5,15 @@ import (
"crypto/sha256"
"encoding/base64"
"fmt"
"github.com/Azure/azure-event-hubs-go/auth"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/Azure/azure-event-hubs-go/auth"
"github.com/Azure/azure-event-hubs-go/common"
"github.com/pkg/errors"
)
type (
@ -33,6 +37,36 @@ func NewProvider(namespace, keyName, key string) auth.TokenProvider {
}
}
// NewProviderFromEnvironment creates a new SAS TokenProvider from environment variables
//
// Expected Environment Variables:
// - "EVENTHUB_NAMESPACE" the namespace of the Event Hub instance
// - "EVENTHUB_KEY_NAME" the name of the Event Hub key
// - "EVENTHUB_KEY_VALUE" the secret for the Event Hub key named in "EVENTHUB_KEY_NAME"
func NewProviderFromEnvironment() (auth.TokenProvider, error) {
var provider auth.TokenProvider
keyName := os.Getenv("EVENTHUB_KEY_NAME")
keyValue := os.Getenv("EVENTHUB_KEY_VALUE")
namespace := os.Getenv("EVENTHUB_NAMESPACE")
connStr := os.Getenv("EVENTHUB_CONNECTION_STRING")
if (keyName == "" || keyValue == "" || namespace == "") && connStr == "" {
return nil, errors.New("unable to build SAS token provider because (EVENTHUB_KEY_NAME, EVENTHUB_KEY_VALUE and EVENTHUB_NAMESPACE) were empty, and EVENTHUB_CONNECTION_STRING was empty")
}
if connStr != "" {
parsed, err := common.ParsedConnectionFromStr(connStr)
if err != nil {
return nil, err
}
provider = NewProvider(parsed.Namespace, parsed.KeyName, parsed.Key)
} else {
provider = NewProvider(namespace, keyName, keyValue)
}
return provider, nil
}
// GetToken gets a CBS SAS token
func (t *TokenProvider) GetToken(audience string) (*auth.Token, error) {
signed, expiry := t.signer.SignWithDuration(audience, 2*time.Hour)
@ -56,10 +90,10 @@ func (s *Signer) SignWithDuration(uri string, interval time.Duration) (signed, e
// SignWithExpiry signs a given uri with a given expiry string
func (s *Signer) SignWithExpiry(uri, expiry string) string {
u := strings.ToLower(url.QueryEscape(uri))
sts := stringToSign(u, expiry)
audience := strings.ToLower(url.QueryEscape(uri))
sts := stringToSign(audience, expiry)
sig := s.signString(sts)
return fmt.Sprintf("SharedAccessSignature sig=%s&se=%s&skn=%s&sr=%s", sig, expiry, s.keyName, u)
return fmt.Sprintf("SharedAccessSignature sr=%s&sig=%s&se=%s&skn=%s", audience, sig, expiry, s.keyName)
}
func signatureExpiry(from time.Time, interval time.Duration) string {

Просмотреть файл

@ -3,6 +3,7 @@ package eventhub
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"pack.ag/amqp"
)
@ -22,48 +23,33 @@ type (
)
// newSender creates a new Service Bus message sender given an AMQP client and entity path
func (h *hub) newSender() (*sender, error) {
func (h *hub) newSender(ctx context.Context) (*sender, error) {
s := &sender{
hub: h,
partitionID: h.senderPartitionID,
}
log.Debugf("creating a new sender for entity path %s", s.getAddress())
err := s.newSessionAndLink()
if err != nil {
return nil, err
}
return s, nil
err := s.newSessionAndLink(ctx)
return s, err
}
// Recover will attempt to close the current session and link, then rebuild them
func (s *sender) Recover() error {
func (s *sender) Recover(ctx context.Context) error {
err := s.Close()
if err != nil {
return err
}
err = s.newSessionAndLink()
if err != nil {
return err
}
return nil
return s.newSessionAndLink(ctx)
}
// Close will close the AMQP session and link of the sender
func (s *sender) Close() error {
err := s.sender.Close()
if err != nil {
_ = s.session.Close()
return err
}
err = s.session.Close()
if err != nil {
return err
}
return nil
return s.session.Close()
}
// Send will send a message to the entity path with options
@ -82,11 +68,7 @@ func (s *sender) Send(ctx context.Context, msg *amqp.Message, opts ...SendOption
msg.Annotations["x-opt-partition-key"] = s.partitionID
}
err := s.sender.Send(ctx, msg)
if err != nil {
return err
}
return nil
return s.sender.Send(ctx, msg)
}
//func (s *sender) SendBatch(ctx context.Context, messages []*amqp.Message) error {
@ -115,8 +97,8 @@ func (s *sender) prepareMessage(msg *amqp.Message) {
}
// newSessionAndLink will replace the existing session and link
func (s *sender) newSessionAndLink() error {
err := s.hub.namespace.negotiateClaim(s.getAddress())
func (s *sender) newSessionAndLink(ctx context.Context) error {
err := s.hub.namespace.negotiateClaim(ctx, s.getAddress())
if err != nil {
return err
}

Просмотреть файл

@ -1,9 +1,10 @@
package eventhub
import (
"sync/atomic"
"github.com/satori/go.uuid"
"pack.ag/amqp"
"sync/atomic"
)
type (
@ -20,7 +21,6 @@ func newSession(amqpSession *amqp.Session) *session {
return &session{
Session: amqpSession,
SessionID: uuid.NewV4().String(),
counter: 0,
}
}

Просмотреть файл

@ -4,18 +4,18 @@ import (
"context"
"errors"
"flag"
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
rm "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2017-05-10/resources"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"math/rand"
"net/http"
"os"
"testing"
"time"
mgmt "github.com/Azure/azure-sdk-for-go/services/eventhub/mgmt/2017-04-01/eventhub"
rm "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2017-05-10/resources"
"github.com/Azure/go-autorest/autorest/azure"
azauth "github.com/Azure/go-autorest/autorest/azure/auth"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
)
var (
@ -32,19 +32,9 @@ type (
// eventHubSuite encapsulates a end to end test of Event Hubs with build up and tear down of all EH resources
eventHubSuite struct {
suite.Suite
tenantID string
subscriptionID string
clientID string
clientSecret string
namespace string
env azure.Environment
armToken *adal.ServicePrincipalToken
}
servicePrincipalCredentials struct {
TenantID string
ApplicationID string
Secret string
}
// HubMgmtOption represents an option for configuring an Event Hub.
@ -67,13 +57,20 @@ func (suite *eventHubSuite) SetupSuite() {
log.SetLevel(log.DebugLevel)
}
suite.tenantID = mustGetEnv("AZURE_TENANT_ID")
suite.subscriptionID = mustGetEnv("AZURE_SUBSCRIPTION_ID")
suite.clientID = mustGetEnv("AZURE_CLIENT_ID")
suite.clientSecret = mustGetEnv("AZURE_CLIENT_SECRET")
suite.namespace = mustGetEnv("EVENTHUB_NAMESPACE")
suite.env = azure.PublicCloud
suite.armToken = suite.servicePrincipalToken()
envName := os.Getenv("AZURE_ENVIRONMENT")
if envName == "" {
suite.env = azure.PublicCloud
} else {
var err error
env, err := azure.EnvironmentFromName(envName)
if err != nil {
log.Fatal(err)
}
suite.env = env
}
err := suite.ensureProvisioned(mgmt.SkuTierStandard)
if err != nil {
@ -86,7 +83,7 @@ func (suite *eventHubSuite) TearDownSuite() {
}
func (suite *eventHubSuite) ensureProvisioned(tier mgmt.SkuTier) error {
_, err := ensureResourceGroup(context.Background(), suite.subscriptionID, ResourceGroupName, Location, suite.armToken, suite.env)
_, err := ensureResourceGroup(context.Background(), suite.subscriptionID, ResourceGroupName, Location, suite.env)
if err != nil {
return err
}
@ -99,27 +96,9 @@ func (suite *eventHubSuite) ensureProvisioned(tier mgmt.SkuTier) error {
return nil
}
func (suite *eventHubSuite) servicePrincipalToken() *adal.ServicePrincipalToken {
oauthConfig, err := adal.NewOAuthConfig(suite.env.ActiveDirectoryEndpoint, suite.tenantID)
if err != nil {
log.Fatalln(err)
}
tokenProvider, err := adal.NewServicePrincipalToken(*oauthConfig,
suite.clientID,
suite.clientSecret,
suite.env.ResourceManagerEndpoint)
if err != nil {
log.Fatalln(err)
}
return tokenProvider
}
// ensureResourceGroup creates a Azure Resource Group if it does not already exist
func ensureResourceGroup(ctx context.Context, subscriptionID, name, location string, armToken *adal.ServicePrincipalToken, env azure.Environment) (*rm.Group, error) {
groupClient := getRmGroupClientWithToken(subscriptionID, armToken, env)
func ensureResourceGroup(ctx context.Context, subscriptionID, name, location string, env azure.Environment) (*rm.Group, error) {
groupClient := getRmGroupClientWithToken(subscriptionID, env)
group, err := groupClient.Get(ctx, name)
if group.StatusCode == http.StatusNotFound {
@ -135,13 +114,13 @@ func ensureResourceGroup(ctx context.Context, subscriptionID, name, location str
}
// ensureNamespace creates a Azure Event Hub Namespace if it does not already exist
func ensureNamespace(ctx context.Context, subscriptionID, rg, name, location string, armToken *adal.ServicePrincipalToken, env azure.Environment, opts ...namespaceMgmtOption) (*mgmt.EHNamespace, error) {
_, err := ensureResourceGroup(ctx, subscriptionID, rg, location, armToken, env)
func ensureNamespace(ctx context.Context, subscriptionID, rg, name, location string, env azure.Environment, opts ...namespaceMgmtOption) (*mgmt.EHNamespace, error) {
_, err := ensureResourceGroup(ctx, subscriptionID, rg, location, env)
if err != nil {
return nil, err
}
client := getNamespaceMgmtClientWithToken(subscriptionID, armToken, env)
client := getNamespaceMgmtClientWithToken(subscriptionID, env)
namespace, err := client.Get(ctx, rg, name)
if err != nil {
return nil, err
@ -237,38 +216,54 @@ func (suite *eventHubSuite) deleteEventHub(ctx context.Context, name string) err
func (suite *eventHubSuite) getEventHubMgmtClient() *mgmt.EventHubsClient {
client := mgmt.NewEventHubsClientWithBaseURI(suite.env.ResourceManagerEndpoint, suite.subscriptionID)
client.Authorizer = autorest.NewBearerAuthorizer(suite.armToken)
a, err := azauth.NewAuthorizerFromEnvironment()
if err != nil {
log.Fatal(err)
}
client.Authorizer = a
return &client
}
func (suite *eventHubSuite) getNamespaceMgmtClient() *mgmt.NamespacesClient {
return getNamespaceMgmtClientWithToken(suite.subscriptionID, suite.armToken, suite.env)
return getNamespaceMgmtClientWithToken(suite.subscriptionID, suite.env)
}
func getNamespaceMgmtClientWithToken(subscriptionID string, armToken *adal.ServicePrincipalToken, env azure.Environment) *mgmt.NamespacesClient {
func getNamespaceMgmtClientWithToken(subscriptionID string, env azure.Environment) *mgmt.NamespacesClient {
client := mgmt.NewNamespacesClientWithBaseURI(env.ResourceManagerEndpoint, subscriptionID)
client.Authorizer = autorest.NewBearerAuthorizer(armToken)
a, err := azauth.NewAuthorizerFromEnvironment()
if err != nil {
log.Fatal(err)
}
client.Authorizer = a
return &client
}
func (suite *eventHubSuite) getNamespaceMgmtClientWithCredentials(ctx context.Context, subscriptionID, rg, name string) *mgmt.NamespacesClient {
client := mgmt.NewNamespacesClientWithBaseURI(suite.env.ResourceManagerEndpoint, suite.subscriptionID)
client.Authorizer = autorest.NewBearerAuthorizer(suite.armToken)
a, err := azauth.NewAuthorizerFromEnvironment()
if err != nil {
log.Fatal(err)
}
client.Authorizer = a
return &client
}
func (suite *eventHubSuite) getRmGroupClient() *rm.GroupsClient {
return getRmGroupClientWithToken(suite.subscriptionID, suite.armToken, suite.env)
return getRmGroupClientWithToken(suite.subscriptionID, suite.env)
}
func getRmGroupClientWithToken(subscriptionID string, armToken *adal.ServicePrincipalToken, env azure.Environment) *rm.GroupsClient {
func getRmGroupClientWithToken(subscriptionID string, env azure.Environment) *rm.GroupsClient {
groupsClient := rm.NewGroupsClientWithBaseURI(env.ResourceManagerEndpoint, subscriptionID)
groupsClient.Authorizer = autorest.NewBearerAuthorizer(armToken)
a, err := azauth.NewAuthorizerFromEnvironment()
if err != nil {
log.Fatal(err)
}
groupsClient.Authorizer = a
return &groupsClient
}
func (suite *eventHubSuite) ensureResourceGroup() (*rm.Group, error) {
group, err := ensureResourceGroup(context.Background(), suite.subscriptionID, suite.namespace, Location, suite.armToken, suite.env)
group, err := ensureResourceGroup(context.Background(), suite.subscriptionID, suite.namespace, Location, suite.env)
if err != nil {
return nil, err
}
@ -276,37 +271,13 @@ func (suite *eventHubSuite) ensureResourceGroup() (*rm.Group, error) {
}
func (suite *eventHubSuite) ensureNamespace() (*mgmt.EHNamespace, error) {
ns, err := ensureNamespace(context.Background(), suite.subscriptionID, ResourceGroupName, suite.namespace, Location, suite.armToken, suite.env)
ns, err := ensureNamespace(context.Background(), suite.subscriptionID, ResourceGroupName, suite.namespace, Location, suite.env)
if err != nil {
return nil, err
}
return ns, err
}
func (suite *eventHubSuite) getEventHubsTokenProvider() (*adal.ServicePrincipalToken, error) {
// TODO: fix the azure environment var for the SB endpoint and EH endpoint
return suite.getTokenProvider("https://eventhubs.azure.net/")
}
func (suite *eventHubSuite) getTokenProvider(resourceURI string) (*adal.ServicePrincipalToken, error) {
oauthConfig, err := adal.NewOAuthConfig(suite.env.ActiveDirectoryEndpoint, suite.tenantID)
if err != nil {
log.Fatalln(err)
}
tokenProvider, err := adal.NewServicePrincipalToken(*oauthConfig, suite.clientID, suite.clientSecret, resourceURI)
if err != nil {
return nil, err
}
err = tokenProvider.Refresh()
if err != nil {
return nil, err
}
return tokenProvider, nil
}
func mustGetEnv(key string) string {
v := os.Getenv(key)
if v == "" {