diff --git a/cbs.go b/cbs.go index e08ba60..ee31a12 100644 --- a/cbs.go +++ b/cbs.go @@ -51,7 +51,8 @@ func (sb *serviceBus) newCbsLink() (*cbsLink, error) { cbsClientAddress := cbsReplyToPrefix + sb.name.String() authReceiver, err := authSession.NewReceiver( amqp.LinkSourceAddress(cbsAddress), - amqp.LinkTargetAddress(cbsClientAddress)) + amqp.LinkTargetAddress(cbsClientAddress), + ) if err != nil { return nil, err } @@ -68,14 +69,13 @@ func (sb *serviceBus) ensureCbsLink() error { sb.cbsMu.Lock() defer sb.cbsMu.Unlock() - if sb.cbsLink == nil { - link, err := sb.newCbsLink() - if err != nil { - return err - } - sb.cbsLink = link + if sb.cbsLink != nil { + return nil } - return nil + + link, err := sb.newCbsLink() + sb.cbsLink = link + return err } func (sb *serviceBus) negotiateClaim(entityPath string) error { @@ -103,31 +103,40 @@ func (sb *serviceBus) negotiateClaim(entityPath string) error { _, err = retry(3, 1*time.Second, func() (interface{}, error) { log.Debugf("Attempting to negotiate cbs for %s in namespace %s", entityPath, sb.namespace) - err := sb.cbsLink.send(context.Background(), msg) + + ctx := context.Background() + + err := sb.cbsLink.send(ctx, msg) if err != nil { return nil, err } - res, err := sb.cbsLink.receive(context.Background()) + res, err := sb.cbsLink.receive(ctx) if err != nil { return nil, err } - if statusCode, ok := res.ApplicationProperties[cbsStatusCodeKey].(int32); ok { - description := res.ApplicationProperties[cbsDescriptionKey].(string) - if statusCode >= 200 && statusCode < 300 { - log.Debugf("Successfully negotiated cbs for %s in namespace %s", entityPath, sb.namespace) - return res, nil - } else if statusCode >= 500 { - log.Debugf("Re-negotiating cbs for %s in namespace %s. Received status code: %d and error: %s", entityPath, sb.namespace, statusCode, description) - return nil, &retryable{message: "cbs error: " + description} - } else { - log.Debugf("Failed negotiating cbs for %s in namespace %s with error %d", entityPath, sb.namespace, statusCode) - return nil, fmt.Errorf("cbs error: failed with code %d and message: %s", statusCode, description) - } + statusCode, ok := res.ApplicationProperties[cbsStatusCodeKey].(int32) + if !ok { + return nil, &retryable{message: "cbs error: didn't understand the replied message status code"} } - return nil, &retryable{message: "cbs error: didn't understand the replied message status code"} + description, ok := res.ApplicationProperties[cbsDescriptionKey].(string) + if !ok { + return nil, &retryable{message: "cbs error: didn't understand the replied message description"} + } + + switch { + case statusCode >= 200 && statusCode < 300: + log.Debugf("Successfully negotiated cbs for %s in namespace %s", entityPath, sb.namespace) + return res, nil + case statusCode >= 500: + log.Debugf("Re-negotiating cbs for %s in namespace %s. Received status code: %d and error: %s", entityPath, sb.namespace, statusCode, description) + return nil, &retryable{message: "cbs error: " + description} + default: + log.Debugf("Failed negotiating cbs for %s in namespace %s with error %d", entityPath, sb.namespace, statusCode) + return nil, fmt.Errorf("cbs error: failed with code %d and message: %s", statusCode, description) + } }) return err diff --git a/helpers.go b/helpers.go index 710b60a..4fc13d3 100644 --- a/helpers.go +++ b/helpers.go @@ -47,10 +47,8 @@ func parseAzureResourceID(id string) (*resourceID, error) { } path := idURL.Path - path = strings.TrimSpace(path) path = strings.Trim(path, "/") - components := strings.Split(path, "/") // We should have an even number of key-value pairs. @@ -58,57 +56,45 @@ func parseAzureResourceID(id string) (*resourceID, error) { return nil, fmt.Errorf("the number of path segments is not divisible by 2 in %q", path) } - var subscriptionID string + idObj := &resourceID{ + Path: make(map[string]string, len(components)/2), + } // Put the constituent key-value pairs into a map - componentMap := make(map[string]string, len(components)/2) for current := 0; current < len(components); current += 2 { key := components[current] value := components[current+1] + switch { // Check key/value for empty strings. - if key == "" || value == "" { + case key == "" || value == "": return nil, fmt.Errorf("key/value cannot be empty strings. Key: '%s', Value: '%s'", key, value) - } // Catch the subscriptionID before it can be overwritten by another "subscriptions" // value in the ID which is the case for the Service Bus subscription resource - if key == "subscriptions" && subscriptionID == "" { - subscriptionID = value - } else { - componentMap[key] = value - } - } + case idObj.SubscriptionID == "" && key == "subscriptions": + idObj.SubscriptionID = value - // Build up a ResourceID from the map - idObj := &resourceID{} - idObj.Path = componentMap - - if subscriptionID != "" { - idObj.SubscriptionID = subscriptionID - } else { - return nil, fmt.Errorf("no subscription ID found in: %q", path) - } - - if resourceGroup, ok := componentMap["resourceGroups"]; ok { - idObj.ResourceGroup = resourceGroup - delete(componentMap, "resourceGroups") - } else { // Some Azure APIs are weird and provide things in lower case... // However it's not clear whether the casing of other elements in the URI // matter, so we explicitly look for that case here. - if resourceGroup, ok := componentMap["resourcegroups"]; ok { - idObj.ResourceGroup = resourceGroup - delete(componentMap, "resourcegroups") - } else { - return nil, fmt.Errorf("no resource group name found in: %q", path) + case strings.EqualFold(key, "resourceGroups"): + idObj.ResourceGroup = value + + case key == "providers": + idObj.Provider = value + + default: + idObj.Path[key] = value } } // It is OK not to have a provider in the case of a resource group - if provider, ok := componentMap["providers"]; ok { - idObj.Provider = provider - delete(componentMap, "providers") + switch { + case idObj.SubscriptionID == "": + return nil, fmt.Errorf("no subscription ID found in: %q", path) + case idObj.ResourceGroup == "": + return nil, fmt.Errorf("no resource group name found in: %q", path) } return idObj, nil diff --git a/queue.go b/queue.go index f7e9ff2..2f9912a 100644 --- a/queue.go +++ b/queue.go @@ -18,7 +18,7 @@ type ( /* QueueWithPartitioning ensure the created queue will be a partitioned queue. Partitioned queues offer increased storage and availability compared to non-partitioned queues with the trade-off of requiring the following to ensure -FIFO message retreival: +FIFO message retrieval: SessionId. If a message has the SessionId property set, then Service Bus uses the SessionId property as the partition key. This way, all messages that belong to the same session are assigned to the same fragment and handled diff --git a/receiver.go b/receiver.go index 2cabc72..67e398f 100644 --- a/receiver.go +++ b/receiver.go @@ -2,8 +2,6 @@ package servicebus import ( "context" - "net" - "time" "github.com/satori/go.uuid" log "github.com/sirupsen/logrus" @@ -17,7 +15,7 @@ type ( session *session receiver *amqp.Receiver entityPath string - done chan struct{} + done func() Name uuid.UUID } ) @@ -27,21 +25,21 @@ func (sb *serviceBus) newReceiver(entityPath string) (*receiver, error) { receiver := &receiver{ sb: sb, entityPath: entityPath, - done: make(chan struct{}), } err := receiver.newSessionAndLink() - if err != nil { - return nil, err - } - return receiver, nil + return receiver, err } // Close will close the AMQP session and link of the receiver func (r *receiver) Close() error { - close(r.done) - + // This isn't safe to be called concurrently with Listen + if r.done != nil { + r.done() + } err := r.receiver.Close() if err != nil { + // ensure session is closed on receiver error + _ = r.session.Close() return err } @@ -60,67 +58,67 @@ func (r *receiver) Recover() error { // Listen start a listener for messages sent to the entity path func (r *receiver) Listen(handler Handler) { + ctx, done := context.WithCancel(context.Background()) + r.done = done messages := make(chan *amqp.Message) - go r.listenForMessages(messages) - go r.handleMessages(messages, handler) + go r.listenForMessages(ctx, messages) + go r.handleMessages(ctx, messages, handler) } -func (r *receiver) handleMessages(messages chan *amqp.Message, handler Handler) { +func (r *receiver) handleMessages(ctx context.Context, messages chan *amqp.Message, handler Handler) { for { select { - case <-r.done: + case <-ctx.Done(): log.Debug("done handling messages") - close(messages) return case msg := <-messages: - ctx := context.Background() - id := interface{}("null") + var id interface{} = "null" if msg.Properties != nil { id = msg.Properties.MessageID } + log.Debugf("Message id: %s is being passed to handler", id) err := handler(ctx, msg) - if err != nil { msg.Reject() log.Debugf("Message rejected: id: %s", id) - } else { - // Accept message - msg.Accept() - log.Debugf("Message accepted: id: %s", id) + continue } + + // Accept message + msg.Accept() + log.Debugf("Message accepted: id: %s", id) } } } -func (r *receiver) listenForMessages(msgChan chan *amqp.Message) { +func (r *receiver) listenForMessages(ctx context.Context, msgChan chan *amqp.Message) { for { - select { - case <-r.done: - log.Debug("done listenting for messages") - return - default: - //log.Debug("attempting to receive messages") - waitCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - msg, err := r.receiver.Receive(waitCtx) - cancel() + //log.Debug("attempting to receive messages") + msg, err := r.receiver.Receive(ctx) + // TODO (vcabbage): This previously checked `net.Error.Timeout() == true`, which + // should never happen. If it does it's a bug in pack.ag/amqp. + if err != nil { + if ctx.Err() != nil { + return + } - // TODO: handle receive errors better. It's not sufficient to check only for timeout - if err, ok := err.(net.Error); ok && err.Timeout() { - log.Debug("attempting to receive messages timed out") - continue - } else if err != nil { - log.Fatalln(err) - time.Sleep(10 * time.Second) - } - if msg != nil { - id := interface{}("null") - if msg.Properties != nil { - id = msg.Properties.MessageID - } - log.Debugf("Message received: %s", id) - msgChan <- msg - } + // TODO (vcabbage): I'm not sure what the appropriate action is here, this was + // previously a call to `log.Fatalln`, which calls os.Exit(1). + log.Error(err) + return + } + + var id interface{} = "null" + if msg.Properties != nil { + id = msg.Properties.MessageID + } + log.Debugf("Message received: %s", id) + + select { + case msgChan <- msg: + case <-ctx.Done(): + return } } } @@ -141,7 +139,8 @@ func (r *receiver) newSessionAndLink() error { amqpReceiver, err := amqpSession.NewReceiver( amqp.LinkSourceAddress(r.entityPath), - amqp.LinkCredit(10)) + amqp.LinkCredit(10), + ) if err != nil { return err } diff --git a/sender.go b/sender.go index 515e901..f49e691 100644 --- a/sender.go +++ b/sender.go @@ -27,11 +27,7 @@ func (sb *serviceBus) newSender(entityPath string) (*sender, error) { log.Debugf("creating a new sender for entity path %s", entityPath) err := s.newSessionAndLink() - if err != nil { - return nil, err - } - - return s, nil + return s, err } // Recover will attempt to close the current session and link, then rebuild them @@ -48,6 +44,7 @@ func (s *sender) Recover() error { func (s *sender) Close() error { err := s.sender.Close() if err != nil { + _ = s.session.Close() return err } diff --git a/servicebus.go b/servicebus.go index 6681216..caf31cb 100644 --- a/servicebus.go +++ b/servicebus.go @@ -3,7 +3,6 @@ package servicebus import ( "context" "errors" - "fmt" "regexp" "sync" @@ -165,7 +164,7 @@ func (sb *serviceBus) Close() error { return nil } -// Listen subscribes for messages sent to the provided entityPath. +// Receive subscribes for messages sent to the provided entityPath. func (sb *serviceBus) Receive(entityPath string, handler Handler) error { sb.receiverMu.Lock() defer sb.receiverMu.Unlock() @@ -194,6 +193,8 @@ func (sb *serviceBus) connection() (*amqp.Client, error) { sb.clientMu.Lock() defer sb.clientMu.Unlock() + // TODO (vcabbage): this will return nil, nil if sb.claimsBasedSecurityEnabled() == false. + // eventually leading to a panic when the consuming code calls conn.NewSession() if sb.client == nil && sb.claimsBasedSecurityEnabled() { host := getHostName(sb.namespace) client, err := amqp.Dial(host, amqp.ConnSASLAnonymous(), amqp.ConnMaxSessions(65535)) @@ -266,11 +267,7 @@ func newWithConnectionString(connStr string) (*serviceBus, error) { // parsedConnectionFromStr takes a string connection string from the Azure portal and returns the parsed representation. func parsedConnectionFromStr(connStr string) (*parsedConn, error) { matches := connStrRegex.FindStringSubmatch(connStr) - parsed, err := newParsedConnection(matches[1], matches[2], matches[3]) - if err != nil { - return nil, err - } - return parsed, nil + return newParsedConnection(matches[1], matches[2], matches[3]) } // newParsedConnection is a constructor for a parsedConn and verifies each of the inputs is non-null. @@ -286,7 +283,7 @@ func newParsedConnection(host string, keyName string, key string) (*parsedConn, } func getHostName(namespace string) string { - return fmt.Sprintf("amqps://%s.%s", namespace, "servicebus.windows.net") + return "amqps://" + namespace + ".servicebus.windows.net" } // claimsBasedSecurityEnabled indicates that the connection will use AAD JWT RBAC to authenticate in connections @@ -305,7 +302,7 @@ func getServiceBusTokenProvider(credential ServicePrincipalCredentials, env azur func getTokenProvider(resourceURI string, cred ServicePrincipalCredentials, env azure.Environment) (*adal.ServicePrincipalToken, error) { oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, cred.TenantID) if err != nil { - log.Fatalln(err) + return nil, err } tokenProvider, err := adal.NewServicePrincipalToken(*oauthConfig, cred.ApplicationID, cred.Secret, resourceURI) @@ -327,6 +324,7 @@ func (sb *serviceBus) drainReceivers() error { defer sb.receiverMu.Unlock() for _, receiver := range sb.receivers { + // TODO (vcabbage): what if an error occurs here? receiver.Close() } sb.receivers = []*receiver{} diff --git a/session.go b/session.go index bde809a..2da3568 100644 --- a/session.go +++ b/session.go @@ -21,7 +21,6 @@ func newSession(amqpSession *amqp.Session) *session { return &session{ Session: amqpSession, SessionID: uuid.NewV4(), - counter: 0, } }