From c38cea319c9d7c9fd30256bad8bbebce02e55478 Mon Sep 17 00:00:00 2001 From: David Justice Date: Wed, 21 Mar 2018 13:51:34 -0700 Subject: [PATCH] initial commit for common amqp deps --- .gitignore | 297 ++---------------------------------------- Gopkg.lock | 84 ++++++++++++ Gopkg.toml | 7 + Makefile | 98 ++++++++++++++ aad/jwt.go | 233 +++++++++++++++++++++++++++++++++ auth/token.go | 36 +++++ cbs/cbs.go | 56 ++++++++ conn/conn.go | 47 +++++++ conn/conn_test.go | 25 ++++ persist/checkpoint.go | 47 +++++++ persist/persist.go | 59 +++++++++ ptrs.go | 22 ++++ retry.go | 32 +++++ rpc/rpc.go | 182 ++++++++++++++++++++++++++ sas/sas.go | 140 ++++++++++++++++++++ uuid/uuid.go | 85 ++++++++++++ 16 files changed, 1167 insertions(+), 283 deletions(-) create mode 100644 Gopkg.lock create mode 100644 Gopkg.toml create mode 100644 Makefile create mode 100644 aad/jwt.go create mode 100644 auth/token.go create mode 100644 cbs/cbs.go create mode 100644 conn/conn.go create mode 100644 conn/conn_test.go create mode 100644 persist/checkpoint.go create mode 100644 persist/persist.go create mode 100644 ptrs.go create mode 100644 retry.go create mode 100644 rpc/rpc.go create mode 100644 sas/sas.go create mode 100644 uuid/uuid.go diff --git a/.gitignore b/.gitignore index 940794e..c805fda 100644 --- a/.gitignore +++ b/.gitignore @@ -1,288 +1,19 @@ -## Ignore Visual Studio temporary files, build results, and -## files generated by popular Visual Studio add-ons. -## -## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib -# User-specific files -*.suo -*.user -*.userosscache -*.sln.docstates +# Test binary, build with `go test -c` +*.test -# User-specific files (MonoDevelop/Xamarin Studio) -*.userprefs +# Output of the go coverage tool, specifically when used with LiteIDE +*.out -# Build results -[Dd]ebug/ -[Dd]ebugPublic/ -[Rr]elease/ -[Rr]eleases/ -x64/ -x86/ -bld/ -[Bb]in/ -[Oo]bj/ -[Ll]og/ +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ -# Visual Studio 2015 cache/options directory -.vs/ -# Uncomment if you have tasks that create the project's static files in wwwroot -#wwwroot/ +vendor +.idea -# MSTest test Results -[Tt]est[Rr]esult*/ -[Bb]uild[Ll]og.* - -# NUNIT -*.VisualState.xml -TestResult.xml - -# Build Results of an ATL Project -[Dd]ebugPS/ -[Rr]eleasePS/ -dlldata.c - -# .NET Core -project.lock.json -project.fragment.lock.json -artifacts/ -**/Properties/launchSettings.json - -*_i.c -*_p.c -*_i.h -*.ilk -*.meta -*.obj -*.pch -*.pdb -*.pgc -*.pgd -*.rsp -*.sbr -*.tlb -*.tli -*.tlh -*.tmp -*.tmp_proj -*.log -*.vspscc -*.vssscc -.builds -*.pidb -*.svclog -*.scc - -# Chutzpah Test files -_Chutzpah* - -# Visual C++ cache files -ipch/ -*.aps -*.ncb -*.opendb -*.opensdf -*.sdf -*.cachefile -*.VC.db -*.VC.VC.opendb - -# Visual Studio profiler -*.psess -*.vsp -*.vspx -*.sap - -# TFS 2012 Local Workspace -$tf/ - -# Guidance Automation Toolkit -*.gpState - -# ReSharper is a .NET coding add-in -_ReSharper*/ -*.[Rr]e[Ss]harper -*.DotSettings.user - -# JustCode is a .NET coding add-in -.JustCode - -# TeamCity is a build add-in -_TeamCity* - -# DotCover is a Code Coverage Tool -*.dotCover - -# Visual Studio code coverage results -*.coverage -*.coveragexml - -# NCrunch -_NCrunch_* -.*crunch*.local.xml -nCrunchTemp_* - -# MightyMoose -*.mm.* -AutoTest.Net/ - -# Web workbench (sass) -.sass-cache/ - -# Installshield output folder -[Ee]xpress/ - -# DocProject is a documentation 
generator add-in -DocProject/buildhelp/ -DocProject/Help/*.HxT -DocProject/Help/*.HxC -DocProject/Help/*.hhc -DocProject/Help/*.hhk -DocProject/Help/*.hhp -DocProject/Help/Html2 -DocProject/Help/html - -# Click-Once directory -publish/ - -# Publish Web Output -*.[Pp]ublish.xml -*.azurePubxml -# TODO: Comment the next line if you want to checkin your web deploy settings -# but database connection strings (with potential passwords) will be unencrypted -*.pubxml -*.publishproj - -# Microsoft Azure Web App publish settings. Comment the next line if you want to -# checkin your Azure Web App publish settings, but sensitive information contained -# in these scripts will be unencrypted -PublishScripts/ - -# NuGet Packages -*.nupkg -# The packages folder can be ignored because of Package Restore -**/packages/* -# except build/, which is used as an MSBuild target. -!**/packages/build/ -# Uncomment if necessary however generally it will be regenerated when needed -#!**/packages/repositories.config -# NuGet v3's project.json files produces more ignorable files -*.nuget.props -*.nuget.targets - -# Microsoft Azure Build Output -csx/ -*.build.csdef - -# Microsoft Azure Emulator -ecf/ -rcf/ - -# Windows Store app package directories and files -AppPackages/ -BundleArtifacts/ -Package.StoreAssociation.xml -_pkginfo.txt - -# Visual Studio cache files -# files ending in .cache can be ignored -*.[Cc]ache -# but keep track of directories ending in .cache -!*.[Cc]ache/ - -# Others -ClientBin/ -~$* -*~ -*.dbmdl -*.dbproj.schemaview -*.jfm -*.pfx -*.publishsettings -orleans.codegen.cs - -# Since there are multiple workflows, uncomment next line to ignore bower_components -# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) -#bower_components/ - -# RIA/Silverlight projects -Generated_Code/ - -# Backup & report files from converting an old project file -# to a newer Visual Studio version. Backup files are not needed, -# because we have git ;-) -_UpgradeReport_Files/ -Backup*/ -UpgradeLog*.XML -UpgradeLog*.htm - -# SQL Server files -*.mdf -*.ldf -*.ndf - -# Business Intelligence projects -*.rdl.data -*.bim.layout -*.bim_*.settings - -# Microsoft Fakes -FakesAssemblies/ - -# GhostDoc plugin setting file -*.GhostDoc.xml - -# Node.js Tools for Visual Studio -.ntvs_analysis.dat -node_modules/ - -# Typescript v1 declaration files -typings/ - -# Visual Studio 6 build log -*.plg - -# Visual Studio 6 workspace options file -*.opt - -# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) -*.vbw - -# Visual Studio LightSwitch build output -**/*.HTMLClient/GeneratedArtifacts -**/*.DesktopClient/GeneratedArtifacts -**/*.DesktopClient/ModelManifest.xml -**/*.Server/GeneratedArtifacts -**/*.Server/ModelManifest.xml -_Pvt_Extensions - -# Paket dependency manager -.paket/paket.exe -paket-files/ - -# FAKE - F# Make -.fake/ - -# JetBrains Rider -.idea/ -*.sln.iml - -# CodeRush -.cr/ - -# Python Tools for Visual Studio (PTVS) -__pycache__/ -*.pyc - -# Cake - Uncomment if you are using it -# tools/** -# !tools/packages.config - -# Telerik's JustMock configuration file -*.jmconfig - -# BizTalk build output -*.btp.cs -*.btm.cs -*.odx.cs -*.xsd.cs +.DS_Store diff --git a/Gopkg.lock b/Gopkg.lock new file mode 100644 index 0000000..9e80de2 --- /dev/null +++ b/Gopkg.lock @@ -0,0 +1,84 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 
+ + +[[projects]] + name = "github.com/Azure/go-autorest" + packages = [ + "autorest", + "autorest/adal", + "autorest/azure", + "autorest/date" + ] + revision = "7909b98056dd6f6a9fc9b7745af1810c93c15939" + version = "v10.3.0" + +[[projects]] + name = "github.com/davecgh/go-spew" + packages = ["spew"] + revision = "346938d642f2ec3594ed81d874461961cd0faa76" + version = "v1.1.0" + +[[projects]] + name = "github.com/dgrijalva/jwt-go" + packages = ["."] + revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e" + version = "v3.2.0" + +[[projects]] + name = "github.com/pkg/errors" + packages = ["."] + revision = "645ef00459ed84a119197bfb8d8205042c6df63d" + version = "v0.8.0" + +[[projects]] + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + name = "github.com/sirupsen/logrus" + packages = ["."] + revision = "c155da19408a8799da419ed3eeb0cb5db0ad5dbc" + version = "v1.0.5" + +[[projects]] + name = "github.com/stretchr/testify" + packages = ["assert"] + revision = "12b6f73e6084dad08a7c6e575284b177ecafbc71" + version = "v1.2.1" + +[[projects]] + branch = "master" + name = "golang.org/x/crypto" + packages = [ + "pkcs12", + "pkcs12/internal/rc2", + "ssh/terminal" + ] + revision = "80db560fac1fb3e6ac81dbc7f8ae4c061f5257bd" + +[[projects]] + branch = "master" + name = "golang.org/x/sys" + packages = [ + "unix", + "windows" + ] + revision = "c488ab1dd8481ef762f96a79a9577c27825be697" + +[[projects]] + branch = "master" + name = "pack.ag/amqp" + packages = [ + ".", + "internal/testconn" + ] + revision = "fc71119dfd03ed44d0aba09806e4a7d1584b74b1" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + inputs-digest = "7c743defd623795eff4a690aaf4d7651db81fae6e00d2e432b1ee18d132377c4" + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml new file mode 100644 index 0000000..0c55cfb --- /dev/null +++ b/Gopkg.toml @@ -0,0 +1,7 @@ +[prune] + go-tests = true + unused-packages = true + +[[constraint]] + name = "pack.ag/amqp" + branch = "master" diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..aa218b9 --- /dev/null +++ b/Makefile @@ -0,0 +1,98 @@ +PACKAGE = github.com/Azure/azure-amqp-common-go +DATE ?= $(shell date +%FT%T%z) +VERSION ?= $(shell git describe --tags --always --dirty --match=v* 2> /dev/null || \ + cat $(CURDIR)/.version 2> /dev/null || echo v0) +BIN = $(GOPATH)/bin +BASE = $(GOPATH)/src/$(PACKAGE) +PKGS = $(or $(PKG),$(shell cd $(BASE) && env GOPATH=$(GOPATH) $(GO) list ./... | grep -vE "^$(PACKAGE)/vendor|_examples|templates/")) +TESTPKGS = $(shell env GOPATH=$(GOPATH) $(GO) list -f '{{ if or .TestGoFiles .XTestGoFiles }}{{ .ImportPath }}{{ end }}' $(PKGS)) +GO_FILES = find . 
-iname '*.go' -type f | grep -v /vendor/ + +GO = go +GODOC = godoc +GOFMT = gofmt +GOCYCLO = gocyclo +DEP = dep + +V = 0 +Q = $(if $(filter 1,$V),,@) +M = $(shell printf "\033[34;1m▶\033[0m") +TIMEOUT = 360 + +.PHONY: all +all: fmt vendor lint vet megacheck | $(BASE) ; $(info $(M) building library…) @ ## Build program + $Q cd $(BASE) && $(GO) build \ + -tags release \ + -ldflags '-X $(PACKAGE)/cmd.Version=$(VERSION) -X $(PACKAGE)/cmd.BuildDate=$(DATE)' + +$(BASE): ; $(info $(M) setting GOPATH…) + @mkdir -p $(dir $@) + @ln -sf $(CURDIR) $@ + +# Tools + +GOLINT = $(BIN)/golint +$(BIN)/golint: | $(BASE) ; $(info $(M) building golint…) + $Q go get github.com/golang/lint/golint + +# Tests + +TEST_TARGETS := test-default test-bench test-short test-verbose test-race test-debug +.PHONY: $(TEST_TARGETS) test-xml check test tests +test-bench: ARGS=-run=__absolutelynothing__ -bench=. ## Run benchmarks +test-short: ARGS=-short ## Run only short tests +test-verbose: ARGS=-v ## Run tests in verbose mode +test-debug: ARGS=-v -debug ## Run tests in verbose mode with debug output +test-race: ARGS=-race ## Run tests with race detector +test-cover: ARGS=-cover ## Run tests in verbose mode with coverage +$(TEST_TARGETS): NAME=$(MAKECMDGOALS:test-%=%) +$(TEST_TARGETS): test +check test tests: cyclo lint vet vendor megacheck | $(BASE) ; $(info $(M) running $(NAME:%=% )tests…) @ ## Run tests + $Q cd $(BASE) && $(GO) test -timeout $(TIMEOUT)s $(ARGS) $(TESTPKGS) + +.PHONY: vet +vet: vendor | $(BASE) $(GOLINT) ; $(info $(M) running vet…) @ ## Run vet + $Q cd $(BASE) && $(GO) vet ./... + +.PHONY: lint +lint: vendor | $(BASE) $(GOLINT) ; $(info $(M) running golint…) @ ## Run golint + $Q cd $(BASE) && ret=0 && for pkg in $(PKGS); do \ + test -z "$$($(GOLINT) $$pkg | tee /dev/stderr)" || ret=1 ; \ + done ; exit $$ret + +.PHONY: megacheck +megacheck: vendor | $(BASE) ; $(info $(M) running megacheck…) @ ## Run megacheck + $Q cd $(BASE) && megacheck + +.PHONY: fmt +fmt: ; $(info $(M) running gofmt…) @ ## Run gofmt on all source files + @ret=0 && for d in $$($(GO) list -f '{{.Dir}}' ./... | grep -v /vendor/); do \ + $(GOFMT) -l -w $$d/*.go || ret=$$? ; \ + done ; exit $$ret + +.PHONY: cyclo +cyclo: ; $(info $(M) running gocyclo...) @ ## Run gocyclo on all source files + $Q cd $(BASE) && $(GOCYCLO) -over 19 $$($(GO_FILES)) +# Dependency management + +Gopkg.lock: Gopkg.toml | $(BASE) ; $(info $(M) updating dependencies…) + $Q cd $(BASE) && $(DEP) ensure + @touch $@ +vendor: Gopkg.lock | $(BASE) ; $(info $(M) retrieving dependencies…) + $Q cd $(BASE) && $(DEP) ensure + @touch $@ + +# Misc + +.PHONY: clean +clean: ; $(info $(M) cleaning…) @ ## Cleanup everything + @rm -rf test/tests.* test/coverage.* + +.PHONY: help +help: + @grep -E '^[ a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \ + awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-15s\033[0m %s\n", $$1, $$2}' + +.PHONY: version +version: + @echo $(VERSION) \ No newline at end of file diff --git a/aad/jwt.go b/aad/jwt.go new file mode 100644 index 0000000..03adfa1 --- /dev/null +++ b/aad/jwt.go @@ -0,0 +1,233 @@ +// Package aad provides an implementation of an Azure Active Directory JWT provider which implements TokenProvider +// from package auth for use with Azure Event Hubs and Service Bus. 
+package aad + +import ( + "crypto/rsa" + "crypto/x509" + "fmt" + "io/ioutil" + "os" + "strconv" + "time" + + "github.com/Azure/azure-amqp-common-go/auth" + "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/go-autorest/autorest/azure" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/pkcs12" +) + +const ( + eventhubResourceURI = "https://eventhubs.azure.net/" +) + +type ( + // TokenProviderConfiguration provides configuration parameters for building JWT AAD providers + TokenProviderConfiguration struct { + TenantID string + ClientID string + ClientSecret string + CertificatePath string + CertificatePassword string + ResourceURI string + aadToken *adal.ServicePrincipalToken + Env *azure.Environment + } + + // TokenProvider provides cbs.TokenProvider functionality for Azure Active Directory JWTs + TokenProvider struct { + tokenProvider *adal.ServicePrincipalToken + } + + // JWTProviderOption provides configuration options for constructing AAD Token Providers + JWTProviderOption func(provider *TokenProviderConfiguration) error +) + +// JWTProviderWithAzureEnvironment configures the token provider to use a specific Azure Environment +func JWTProviderWithAzureEnvironment(env *azure.Environment) JWTProviderOption { + return func(config *TokenProviderConfiguration) error { + config.Env = env + return nil + } +} + +// JWTProviderWithEnvironmentVars configures the TokenProvider using the environment variables available +// +// 1. Client Credentials: attempt to authenticate with a Service Principal via "AZURE_TENANT_ID", "AZURE_CLIENT_ID" and +// "AZURE_CLIENT_SECRET" +// +// 2. Client Certificate: attempt to authenticate with a Service Principal via "AZURE_TENANT_ID", "AZURE_CLIENT_ID", +// "AZURE_CERTIFICATE_PATH" and "AZURE_CERTIFICATE_PASSWORD" +// +// 3. Managed Service Identity (MSI): attempt to authenticate via MSI +// +// +// The Azure Environment used can be specified using the name of the Azure Environment set in "AZURE_ENVIRONMENT" var. 
+func JWTProviderWithEnvironmentVars() JWTProviderOption {
+	return func(config *TokenProviderConfiguration) error {
+		config.TenantID = os.Getenv("AZURE_TENANT_ID")
+		config.ClientID = os.Getenv("AZURE_CLIENT_ID")
+		config.ClientSecret = os.Getenv("AZURE_CLIENT_SECRET")
+		config.CertificatePath = os.Getenv("AZURE_CERTIFICATE_PATH")
+		config.CertificatePassword = os.Getenv("AZURE_CERTIFICATE_PASSWORD")
+
+		if config.Env == nil {
+			env, err := azureEnvFromEnvironment()
+			if err != nil {
+				return err
+			}
+			config.Env = env
+		}
+		return nil
+	}
+}
+
+// JWTProviderWithResourceURI configures the token provider to use a specific resource URI
+func JWTProviderWithResourceURI(resourceURI string) JWTProviderOption {
+	return func(config *TokenProviderConfiguration) error {
+		config.ResourceURI = resourceURI
+		return nil
+	}
+}
+
+// JWTProviderWithAADToken configures the token provider to use a specific Azure Active Directory Service Principal token
+func JWTProviderWithAADToken(aadToken *adal.ServicePrincipalToken) JWTProviderOption {
+	return func(config *TokenProviderConfiguration) error {
+		config.aadToken = aadToken
+		return nil
+	}
+}
+
+// NewJWTProvider builds an Azure Active Directory claims-based security token provider
+func NewJWTProvider(opts ...JWTProviderOption) (auth.TokenProvider, error) {
+	config := &TokenProviderConfiguration{
+		ResourceURI: eventhubResourceURI,
+	}
+
+	for _, opt := range opts {
+		err := opt(config)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	if config.aadToken == nil {
+		spToken, err := config.NewServicePrincipalToken()
+		if err != nil {
+			return nil, err
+		}
+		config.aadToken = spToken
+	}
+	return &TokenProvider{tokenProvider: config.aadToken}, nil
+}
+
+// NewServicePrincipalToken creates a new Azure Active Directory Service Principal token
+func (c *TokenProviderConfiguration) NewServicePrincipalToken() (*adal.ServicePrincipalToken, error) {
+	oauthConfig, err := adal.NewOAuthConfig(c.Env.ActiveDirectoryEndpoint, c.TenantID)
+	if err != nil {
+		return nil, err
+	}
+
+	// 1. Client Credentials
+	if c.ClientSecret != "" {
+		log.Debug("creating a token via a service principal client secret")
+		spToken, err := adal.NewServicePrincipalToken(*oauthConfig, c.ClientID, c.ClientSecret, c.ResourceURI)
+		if err != nil {
+			return nil, fmt.Errorf("failed to get oauth token from client credentials: %v", err)
+		}
+		if err := spToken.Refresh(); err != nil {
+			return nil, fmt.Errorf("failed to refresh token: %v", err)
+		}
+		return spToken, nil
+	}
+
+	// 2. Client Certificate
+	if c.CertificatePath != "" {
+		log.Debug("creating a token via a service principal client certificate")
+		certData, err := ioutil.ReadFile(c.CertificatePath)
+		if err != nil {
+			return nil, fmt.Errorf("failed to read the certificate file (%s): %v", c.CertificatePath, err)
+		}
+		certificate, rsaPrivateKey, err := decodePkcs12(certData, c.CertificatePassword)
+		if err != nil {
+			return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
+		}
+		spToken, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, c.ClientID, certificate, rsaPrivateKey, c.ResourceURI)
+		if err != nil {
+			return nil, fmt.Errorf("failed to get oauth token from certificate auth: %v", err)
+		}
+		if err := spToken.Refresh(); err != nil {
+			return nil, fmt.Errorf("failed to refresh token: %v", err)
+		}
+		return spToken, nil
+	}
+
+	// 3. By default, fall back to MSI
+	log.Debug("creating a token via MSI")
+	msiEndpoint, err := adal.GetMSIVMEndpoint()
+	if err != nil {
+		return nil, err
+	}
+	spToken, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, c.ResourceURI)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get oauth token from MSI: %v", err)
+	}
+	if err := spToken.Refresh(); err != nil {
+		return nil, fmt.Errorf("failed to refresh token: %v", err)
+	}
+	return spToken, nil
+}
+
+// GetToken gets a CBS JWT
+func (t *TokenProvider) GetToken(audience string) (*auth.Token, error) {
+	token := t.tokenProvider.Token()
+	expireTicks, err := strconv.ParseInt(token.ExpiresOn, 10, 64)
+	if err != nil {
+		log.Debugf("%v", token.AccessToken)
+		return nil, err
+	}
+	expires := time.Unix(expireTicks, 0)
+
+	if expires.Before(time.Now()) {
+		log.Debug("refreshing AAD token since it has expired")
+		if err := t.tokenProvider.Refresh(); err != nil {
+			log.Error("refreshing AAD token has failed")
+			return nil, err
+		}
+		token = t.tokenProvider.Token()
+		log.Debug("refreshing AAD token has succeeded")
+	}
+
+	return auth.NewToken(auth.CBSTokenTypeJWT, token.AccessToken, token.ExpiresOn), nil
+}
+
+func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
+	privateKey, certificate, err := pkcs12.Decode(pkcs, password)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey)
+	if !isRsaKey {
+		return nil, nil, fmt.Errorf("PKCS#12 certificate must contain an RSA private key")
+	}
+
+	return certificate, rsaPrivateKey, nil
+}
+
+func azureEnvFromEnvironment() (*azure.Environment, error) {
+	envName := os.Getenv("AZURE_ENVIRONMENT")
+
+	var env azure.Environment
+	if envName == "" {
+		env = azure.PublicCloud
+	} else {
+		var err error
+		env, err = azure.EnvironmentFromName(envName)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return &env, nil
+}
diff --git a/auth/token.go b/auth/token.go
new file mode 100644
index 0000000..06f3312
--- /dev/null
+++ b/auth/token.go
@@ -0,0 +1,36 @@
+// Package auth provides an abstraction over claims-based security for Azure Event Hub and Service Bus.
+package auth
+
+const (
+	// CBSTokenTypeJWT is the type of token to be used for JWTs. For example, Azure Active Directory tokens.
+	CBSTokenTypeJWT TokenType = "jwt"
+	// CBSTokenTypeSAS is the type of token to be used for SAS tokens.
+	CBSTokenTypeSAS TokenType = "servicebus.windows.net:sastoken"
+)
+
+type (
+	// TokenType represents types of tokens known for claims-based auth
+	TokenType string
+
+	// Token contains all of the information to negotiate authentication
+	Token struct {
+		// TokenType is the type of CBS token
+		TokenType TokenType
+		Token     string
+		Expiry    string
+	}
+
+	// TokenProvider abstracts the fetching of authentication tokens
+	TokenProvider interface {
+		GetToken(uri string) (*Token, error)
+	}
+)
+
+// NewToken constructs a new auth token
+func NewToken(tokenType TokenType, token, expiry string) *Token {
+	return &Token{
+		TokenType: tokenType,
+		Token:     token,
+		Expiry:    expiry,
+	}
+}
diff --git a/cbs/cbs.go b/cbs/cbs.go
new file mode 100644
index 0000000..735f665
--- /dev/null
+++ b/cbs/cbs.go
@@ -0,0 +1,56 @@
+// Package cbs provides the functionality for negotiating claims-based security over AMQP for use in Azure Service Bus
+// and Event Hubs.
+package cbs
+
+import (
+	"context"
+	"time"
+
+	"github.com/Azure/azure-amqp-common-go/auth"
+	"github.com/Azure/azure-amqp-common-go/rpc"
+	log "github.com/sirupsen/logrus"
+	"pack.ag/amqp"
+)
+
+const (
+	cbsAddress           = "$cbs"
+	cbsOperationKey      = "operation"
+	cbsOperationPutToken = "put-token"
+	cbsTokenTypeKey      = "type"
+	cbsAudienceKey       = "name"
+	cbsExpirationKey     = "expiration"
+)
+
+// NegotiateClaim attempts to put a token to the $cbs management endpoint to negotiate auth for the given audience
+func NegotiateClaim(ctx context.Context, audience string, conn *amqp.Client, provider auth.TokenProvider) error {
+	link, err := rpc.NewLink(conn, cbsAddress)
+	if err != nil {
+		return err
+	}
+	defer link.Close()
+
+	token, err := provider.GetToken(audience)
+	if err != nil {
+		return err
+	}
+
+	log.Debugf("negotiating claim for audience %s with token type %s and expiry of %s", audience, token.TokenType, token.Expiry)
+	msg := &amqp.Message{
+		Value: token.Token,
+		ApplicationProperties: map[string]interface{}{
+			cbsOperationKey:  cbsOperationPutToken,
+			cbsTokenTypeKey:  string(token.TokenType),
+			cbsAudienceKey:   audience,
+			cbsExpirationKey: token.Expiry,
+		},
+	}
+
+	res, err := link.RetryableRPC(ctx, 3, 1*time.Second, msg)
+	if err != nil {
+		log.Error(err)
+		return err
+	}
+
+	log.Debugf("negotiated with response code %d and message: %s", res.Code, res.Description)
+	return nil
+}
diff --git a/conn/conn.go b/conn/conn.go
new file mode 100644
index 0000000..56960d9
--- /dev/null
+++ b/conn/conn.go
@@ -0,0 +1,47 @@
+package conn
+
+import (
+	"errors"
+	"fmt"
+	"regexp"
+)
+
+var (
+	connStrRegex = regexp.MustCompile(`Endpoint=sb:\/\/(?P<host>.+?);SharedAccessKeyName=(?P<keyName>.+?);SharedAccessKey=(?P<key>.+?);EntityPath=(?P<hubName>.+)`)
+	hostStrRegex = regexp.MustCompile(`^(?P<namespace>.+?)\.(.+?)\/`)
+)
+
+type (
+	// ParsedConn is the structure of a parsed Service Bus or Event Hub connection string.
+	ParsedConn struct {
+		Host      string
+		Suffix    string
+		Namespace string
+		HubName   string
+		KeyName   string
+		Key       string
+	}
+)
+
+// newParsedConnection is a constructor for a ParsedConn and verifies each of the required inputs is non-empty.
+func newParsedConnection(host, suffix, namespace, hubName, keyName, key string) (*ParsedConn, error) {
+	if host == "" || keyName == "" || key == "" {
+		return nil, errors.New("connection string contains an empty entry")
+	}
+	return &ParsedConn{
+		Host:      "amqps://" + host,
+		Suffix:    suffix,
+		Namespace: namespace,
+		KeyName:   keyName,
+		Key:       key,
+		HubName:   hubName,
+	}, nil
+}
+
+// ParsedConnectionFromStr takes a connection string from the Azure portal and returns the parsed representation.
+func ParsedConnectionFromStr(connStr string) (*ParsedConn, error) {
+	matches := connStrRegex.FindStringSubmatch(connStr)
+	if len(matches) < 5 {
+		return nil, errors.New("failed to parse connection string")
+	}
+	namespaceMatches := hostStrRegex.FindStringSubmatch(matches[1])
+	if len(namespaceMatches) < 3 {
+		return nil, fmt.Errorf("failed to parse host %q from connection string", matches[1])
+	}
+	return newParsedConnection(matches[1], namespaceMatches[2], namespaceMatches[1], matches[4], matches[2], matches[3])
+}
diff --git a/conn/conn_test.go b/conn/conn_test.go
new file mode 100644
index 0000000..f810ff6
--- /dev/null
+++ b/conn/conn_test.go
@@ -0,0 +1,25 @@
+package conn
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+const (
+	namespace = "mynamespace"
+	keyName   = "keyName"
+	secret    = "superSecret"
+	hubName   = "myhub"
+	connStr   = "Endpoint=sb://" + namespace + ".servicebus.windows.net/;SharedAccessKeyName=" + keyName + ";SharedAccessKey=" + secret + ";EntityPath=" + hubName
+)
+
+func TestParsedConnectionFromStr(t *testing.T) {
+	parsed, err := ParsedConnectionFromStr(connStr)
+	assert.Nil(t, err, err)
+	assert.Equal(t, "amqps://"+namespace+".servicebus.windows.net/", parsed.Host)
+	assert.Equal(t, namespace, parsed.Namespace)
+	assert.Equal(t, keyName, parsed.KeyName)
+	assert.Equal(t, secret, parsed.Key)
+	assert.Equal(t, hubName, parsed.HubName)
+}
diff --git a/persist/checkpoint.go b/persist/checkpoint.go
new file mode 100644
index 0000000..568fc63
--- /dev/null
+++ b/persist/checkpoint.go
@@ -0,0 +1,47 @@
+package persist
+
+import (
+	"time"
+)
+
+const (
+	// StartOfStream is a constant defined to represent the start of a partition stream in EventHub.
+	StartOfStream = "-1"
+
+	// EndOfStream is a constant defined to represent the current end of a partition stream in EventHub.
+	// This can be used as an offset argument in receiver creation to start receiving from the latest
+	// event, instead of a specific offset or point in time.
+	EndOfStream = "@latest"
+)
+
+type (
+	// Checkpoint is the information needed to determine the last message processed
+	Checkpoint struct {
+		Offset         string    `json:"offset"`
+		SequenceNumber int64     `json:"sequenceNumber"`
+		EnqueueTime    time.Time `json:"enqueueTime"`
+	}
+)
+
+// NewCheckpointFromStartOfStream returns a checkpoint for the start of the stream
+func NewCheckpointFromStartOfStream() Checkpoint {
+	return Checkpoint{
+		Offset: StartOfStream,
+	}
+}
+
+// NewCheckpointFromEndOfStream returns a checkpoint for the end of the stream
+func NewCheckpointFromEndOfStream() Checkpoint {
+	return Checkpoint{
+		Offset: EndOfStream,
+	}
+}
+
+// NewCheckpoint returns a checkpoint for the given offset, sequence number and enqueue time
+func NewCheckpoint(offset string, sequence int64, enqueueTime time.Time) Checkpoint {
+	return Checkpoint{
+		Offset:         offset,
+		SequenceNumber: sequence,
+		EnqueueTime:    enqueueTime,
+	}
+}
diff --git a/persist/persist.go b/persist/persist.go
new file mode 100644
index 0000000..ae941e3
--- /dev/null
+++ b/persist/persist.go
@@ -0,0 +1,59 @@
+// Package persist provides abstract structures for checkpoint persistence.
+package persist
+
+import (
+	"github.com/pkg/errors"
+	"path"
+	"sync"
+)
+
+type (
+	// CheckpointPersister provides persistence for the received offset for a given namespace, hub name, consumer group, partition ID and
+	// offset so that if a receiver were to be interrupted, it could resume after the last consumed event.
+	CheckpointPersister interface {
+		Write(namespace, name, consumerGroup, partitionID string, checkpoint Checkpoint) error
+		Read(namespace, name, consumerGroup, partitionID string) (Checkpoint, error)
+	}
+
+	// MemoryPersister is a default implementation of a Hub CheckpointPersister, which will persist offset information in
+	// memory.
+	MemoryPersister struct {
+		values map[string]Checkpoint
+		mu     sync.Mutex
+	}
+)
+
+// NewMemoryPersister creates a new in-memory storage for checkpoints
+//
+// MemoryPersister is only intended to be shared with EventProcessorHosts within the same process. This implementation
+// is a toy. You should probably use the Azure Storage implementation or any other that provides durable storage for
+// checkpoints.
+func NewMemoryPersister() *MemoryPersister {
+	return &MemoryPersister{
+		values: make(map[string]Checkpoint),
+	}
+}
+
+// Write stores the checkpoint in memory for the given namespace, hub name, consumer group and partition ID
+func (p *MemoryPersister) Write(namespace, name, consumerGroup, partitionID string, checkpoint Checkpoint) error {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	key := getPersistenceKey(namespace, name, consumerGroup, partitionID)
+	p.values[key] = checkpoint
+	return nil
+}
+
+// Read returns the last checkpoint written for the given namespace, hub name, consumer group and partition ID, or a
+// start-of-stream checkpoint and an error if none has been written yet
+func (p *MemoryPersister) Read(namespace, name, consumerGroup, partitionID string) (Checkpoint, error) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	key := getPersistenceKey(namespace, name, consumerGroup, partitionID)
+	if offset, ok := p.values[key]; ok {
+		return offset, nil
+	}
+	return NewCheckpointFromStartOfStream(), errors.Errorf("could not read the offset for the key %s", key)
+}
+
+func getPersistenceKey(namespace, name, consumerGroup, partitionID string) string {
+	return path.Join(namespace, name, consumerGroup, partitionID)
+}
diff --git a/ptrs.go b/ptrs.go
new file mode 100644
index 0000000..f326c49
--- /dev/null
+++ b/ptrs.go
@@ -0,0 +1,22 @@
+package common
+
+// PtrBool takes a boolean and returns a pointer to that bool. For use in literal pointers, PtrBool(true) -> *bool
+func PtrBool(toPtr bool) *bool {
+	return &toPtr
+}
+
+// PtrString takes a string and returns a pointer to that string. For use in literal pointers,
+// PtrString(fmt.Sprintf("..", foo)) -> *string
+func PtrString(toPtr string) *string {
+	return &toPtr
+}
+
+// PtrInt32 takes an int32 and returns a pointer to that int32. For use in literal pointers, PtrInt32(1) -> *int32
+func PtrInt32(number int32) *int32 {
+	return &number
+}
+
+// PtrInt64 takes an int64 and returns a pointer to that int64. For use in literal pointers, PtrInt64(1) -> *int64
+func PtrInt64(number int64) *int64 {
+	return &number
+}
diff --git a/retry.go b/retry.go
new file mode 100644
index 0000000..e28a007
--- /dev/null
+++ b/retry.go
@@ -0,0 +1,32 @@
+package common
+
+import (
+	"time"
+)
+
+// Retryable represents an error which may be retried
+type Retryable string
+
+// Error implements the error interface for Retryable
+func (r Retryable) Error() string {
+	return string(r)
+}
+
+// Retry will attempt to retry an action a number of times if the action returns a retryable error
+func Retry(times int, delay time.Duration, action func() (interface{}, error)) (interface{}, error) {
+	var lastErr error
+	for i := 0; i < times; i++ {
+		item, err := action()
+		if err != nil {
+			if err, ok := err.(Retryable); ok {
+				lastErr = err
+				time.Sleep(delay)
+				continue
+			} else {
+				return nil, err
+			}
+		}
+		return item, nil
+	}
+	return nil, lastErr
+}
diff --git a/rpc/rpc.go b/rpc/rpc.go
new file mode 100644
index 0000000..a46190f
--- /dev/null
+++ b/rpc/rpc.go
@@ -0,0 +1,182 @@
+// Package rpc provides functionality for request/reply messaging. It is used by the mgmt and cbs packages.
+package rpc
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/Azure/azure-amqp-common-go"
+	"github.com/Azure/azure-amqp-common-go/uuid"
+	"github.com/pkg/errors"
+	log "github.com/sirupsen/logrus"
+	"pack.ag/amqp"
+)
+
+const (
+	replyPostfix   = "-reply-to-"
+	statusCodeKey  = "status-code"
+	descriptionKey = "status-description"
+)
+
+type (
+	// Link is the bidirectional communication structure used for CBS negotiation
+	Link struct {
+		session       *amqp.Session
+		receiver      *amqp.Receiver
+		sender        *amqp.Sender
+		clientAddress string
+		rpcMu         sync.Mutex
+		id            string
+	}
+
+	// Response is the simplified response structure from an RPC-like call
+	Response struct {
+		Code        int
+		Description string
+		Message     *amqp.Message
+	}
+)
+
+// NewLink will build a new request/response link
+func NewLink(conn *amqp.Client, address string) (*Link, error) {
+	authSession, err := conn.NewSession()
+	if err != nil {
+		return nil, err
+	}
+
+	authSender, err := authSession.NewSender(
+		amqp.LinkTargetAddress(address),
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	linkID, err := uuid.NewV4()
+	if err != nil {
+		return nil, err
+	}
+
+	id := linkID.String()
+	clientAddress := strings.Replace(address, "$", "", -1) + replyPostfix + id
+	authReceiver, err := authSession.NewReceiver(
+		amqp.LinkSourceAddress(address),
+		amqp.LinkTargetAddress(clientAddress),
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Link{
+		sender:        authSender,
+		receiver:      authReceiver,
+		session:       authSession,
+		clientAddress: clientAddress,
+		id:            id,
+	}, nil
+}
+
+// RetryableRPC attempts to retry a request a number of times with delay
+func (l *Link) RetryableRPC(ctx context.Context, times int, delay time.Duration, msg *amqp.Message) (*Response, error) {
+	res, err := common.Retry(times, delay, func() (interface{}, error) {
+		res, err := l.RPC(ctx, msg)
+		if err != nil {
+			log.Debugf("error in RPC via link %s: %v", l.id, err)
+			return nil, err
+		}
+
+		switch {
+		case res.Code >= 200 && res.Code < 300:
+			log.Debugf("successful rpc on link %s: status code %d and description: %s", l.id, res.Code, res.Description)
+			return res, nil
+		case res.Code >= 500:
+			errMessage := fmt.Sprintf("server error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
+			log.Debugln(errMessage)
+			return nil, common.Retryable(errMessage)
+		default:
+			errMessage := fmt.Sprintf("unhandled error link %s: status code %d and description: %s", l.id, res.Code, res.Description)
+			log.Debugln(errMessage)
+			return nil, common.Retryable(errMessage)
+		}
+	})
+	if err != nil {
+		return nil, err
+	}
+	return res.(*Response), nil
+}
+
+// RPC sends a request and waits on a response for that request
+func (l *Link) RPC(ctx context.Context, msg *amqp.Message) (*Response, error) {
+	l.rpcMu.Lock()
+	defer l.rpcMu.Unlock()
+
+	if msg.Properties == nil {
+		msg.Properties = &amqp.MessageProperties{}
+	}
+	msg.Properties.ReplyTo = l.clientAddress
+
+	err := l.sender.Send(ctx, msg)
+	if err != nil {
+		return nil, err
+	}
+
+	res, err := l.receiver.Receive(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	statusCode, ok := res.ApplicationProperties[statusCodeKey].(int32)
+	if !ok {
+		return nil, errors.New("status code was not found on rpc message")
+	}
+
+	description, ok := res.ApplicationProperties[descriptionKey].(string)
+	if !ok {
+		return nil, errors.New("description was not found on rpc message")
+	}
+
+	return &Response{
+		Code:        int(statusCode),
+		Description: description,
+		Message:     res,
+	}, err
+}
+
+// Close the link receiver, sender and session
+func (l *Link) Close() error {
+	if err := l.closeReceiver(); err != nil {
+		_ = l.closeSender()
+		_ = l.closeSession()
+		return err
+	}
+
+	if err := l.closeSender(); err != nil {
+		_ = l.closeSession()
+		return err
+	}
+
+	return l.closeSession()
+}
+
+func (l *Link) closeReceiver() error {
+	if l.receiver != nil {
+		return l.receiver.Close()
+	}
+	return nil
+}
+
+func (l *Link) closeSender() error {
+	if l.sender != nil {
+		return l.sender.Close()
+	}
+	return nil
+}
+
+func (l *Link) closeSession() error {
+	if l.session != nil {
+		return l.session.Close()
+	}
+	return nil
+}
diff --git a/sas/sas.go b/sas/sas.go
new file mode 100644
index 0000000..0dd8c4d
--- /dev/null
+++ b/sas/sas.go
@@ -0,0 +1,140 @@
+// Package sas provides SAS token functionality which implements TokenProvider from package auth for use with Azure
+// Event Hubs and Service Bus.
+package sas
+
+import (
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/base64"
+	"fmt"
+	"net/url"
+	"os"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/Azure/azure-amqp-common-go/auth"
+	"github.com/Azure/azure-amqp-common-go/conn"
+	"github.com/pkg/errors"
+)
+
+type (
+	// Signer provides SAS token generation for use in Service Bus and Event Hub
+	Signer struct {
+		namespace string
+		keyName   string
+		key       string
+	}
+
+	// TokenProvider is a SAS claims-based security token provider
+	TokenProvider struct {
+		signer *Signer
+	}
+
+	// TokenProviderOption provides configuration options for SAS Token Providers
+	TokenProviderOption func(*TokenProvider) error
+)
+
+// TokenProviderWithEnvironmentVars creates a new SAS TokenProvider from environment variables
+//
+// There are two sets of environment variables which can produce a SAS TokenProvider
+//
+// 1) Expected Environment Variables:
+// - "EVENTHUB_NAMESPACE" the namespace of the Event Hub instance
+// - "EVENTHUB_KEY_NAME" the name of the Event Hub key
+// - "EVENTHUB_KEY_VALUE" the secret for the Event Hub key named in "EVENTHUB_KEY_NAME"
+//
+// 2) Expected Environment Variable:
+// - "EVENTHUB_CONNECTION_STRING" connection string from the Azure portal
+//
+// looks like: Endpoint=sb://namespace.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=superSecret1234=
+func TokenProviderWithEnvironmentVars() TokenProviderOption {
+	return func(provider *TokenProvider) error {
+		connStr := os.Getenv("EVENTHUB_CONNECTION_STRING")
+		if connStr != "" {
+			parsed, err := conn.ParsedConnectionFromStr(connStr)
+			if err != nil {
+				return err
+			}
+			provider.signer = NewSigner(parsed.Namespace, parsed.KeyName, parsed.Key)
+			return nil
+		}
+
+		var (
+			keyName   = os.Getenv("EVENTHUB_KEY_NAME")
+			keyValue  = os.Getenv("EVENTHUB_KEY_VALUE")
+			namespace = os.Getenv("EVENTHUB_NAMESPACE")
+		)
+
+		if keyName == "" || keyValue == "" || namespace == "" {
+			return errors.New("unable to build SAS token provider because EVENTHUB_CONNECTION_STRING was empty and one or more of EVENTHUB_KEY_NAME, EVENTHUB_KEY_VALUE and EVENTHUB_NAMESPACE were empty")
+		}
+		provider.signer = NewSigner(namespace, keyName, keyValue)
+		return nil
+	}
+}
+
+// TokenProviderWithNamespaceAndKey configures a SAS TokenProvider to use the given namespace and key combination for signing
+func TokenProviderWithNamespaceAndKey(namespace, keyName, key string) TokenProviderOption {
+	return func(provider *TokenProvider) error {
+		provider.signer = NewSigner(namespace, keyName, key)
+		return nil
+	}
+}
+
+// NewTokenProvider builds a SAS claims-based security token provider
+func NewTokenProvider(opts ...TokenProviderOption) (auth.TokenProvider, error) {
+	provider := new(TokenProvider)
+	for _, opt := range opts {
+		err := opt(provider)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return provider, nil
+}
+
+// GetToken gets a CBS SAS token
+func (t *TokenProvider) GetToken(audience string) (*auth.Token, error) {
+	signed, expiry := t.signer.SignWithDuration(audience, 2*time.Hour)
+	return auth.NewToken(auth.CBSTokenTypeSAS, signed, expiry), nil
+}
+
+// NewSigner builds a new SAS signer for use in generating Service Bus and Event Hub SAS tokens
+func NewSigner(namespace, keyName, key string) *Signer {
+	return &Signer{
+		namespace: namespace,
+		keyName:   keyName,
+		key:       key,
+	}
+}
+
+// SignWithDuration signs a given uri for a period of time from now
+func (s *Signer) SignWithDuration(uri string, interval time.Duration) (signed, expiry string) {
+	expiry = 
signatureExpiry(time.Now(), interval) + return s.SignWithExpiry(uri, expiry), expiry +} + +// SignWithExpiry signs a given uri with a given expiry string +func (s *Signer) SignWithExpiry(uri, expiry string) string { + audience := strings.ToLower(url.QueryEscape(uri)) + sts := stringToSign(audience, expiry) + sig := s.signString(sts) + return fmt.Sprintf("SharedAccessSignature sr=%s&sig=%s&se=%s&skn=%s", audience, sig, expiry, s.keyName) +} + +func signatureExpiry(from time.Time, interval time.Duration) string { + t := from.Add(interval).Round(time.Second).Unix() + return strconv.FormatInt(t, 10) +} + +func stringToSign(uri, expiry string) string { + return uri + "\n" + expiry +} + +func (s *Signer) signString(str string) string { + h := hmac.New(sha256.New, []byte(s.key)) + h.Write([]byte(str)) + encodedSig := base64.StdEncoding.EncodeToString(h.Sum(nil)) + return url.QueryEscape(encodedSig) +} diff --git a/uuid/uuid.go b/uuid/uuid.go new file mode 100644 index 0000000..549939c --- /dev/null +++ b/uuid/uuid.go @@ -0,0 +1,85 @@ +package uuid + +import ( + "crypto/rand" + "encoding/hex" +) + +// Size of a UUID in bytes. +const Size = 16 + +// UUID versions +const ( + _ byte = iota + _ + _ + _ + V4 + _ + + VariantNCS byte = iota + VariantRFC4122 + VariantMicrosoft + VariantFuture +) + +var ( + randomReader = rand.Reader + + // Nil is special form of UUID that is specified to have all + // 128 bits set to zero. + Nil = UUID{} +) + +type ( + // UUID representation compliant with specification + // described in RFC 4122. + UUID [Size]byte +) + +// NewV4 returns random generated UUID. +func NewV4() (UUID, error) { + u := UUID{} + if _, err := randomReader.Read(u[:]); err != nil { + return Nil, err + } + u.setVersion(V4) + u.setVariant(VariantRFC4122) + + return u, nil +} + +func (u *UUID) setVersion(v byte) { + u[6] = (u[6] & 0x0f) | (v << 4) +} + +func (u *UUID) setVariant(v byte) { + switch v { + case VariantNCS: + u[8] = (u[8]&(0xff>>1) | (0x00 << 7)) + case VariantRFC4122: + u[8] = (u[8]&(0xff>>2) | (0x02 << 6)) + case VariantMicrosoft: + u[8] = (u[8]&(0xff>>3) | (0x06 << 5)) + case VariantFuture: + fallthrough + default: + u[8] = (u[8]&(0xff>>3) | (0x07 << 5)) + } +} + +func (u UUID) String() string { + buf := make([]byte, 36) + + hex.Encode(buf[0:8], u[0:4]) + buf[8] = '-' + hex.Encode(buf[9:13], u[4:6]) + buf[13] = '-' + hex.Encode(buf[14:18], u[6:8]) + buf[18] = '-' + hex.Encode(buf[19:23], u[8:10]) + buf[23] = '-' + hex.Encode(buf[24:], u[10:]) + + return string(buf) +}
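
For reviewers, a minimal usage sketch (illustrative only, not part of the patch) of how the new packages are expected to compose: a SAS token provider is built from the EVENTHUB_* environment variables and used to negotiate a claim against the $cbs endpoint. The namespace and hub names are placeholders, and the amqp.Dial / amqp.ConnSASLAnonymous calls assume pack.ag/amqp's connection API.

package main

import (
	"context"
	"time"

	"github.com/Azure/azure-amqp-common-go/cbs"
	"github.com/Azure/azure-amqp-common-go/sas"
	log "github.com/sirupsen/logrus"
	"pack.ag/amqp"
)

func main() {
	// Build a SAS token provider from EVENTHUB_CONNECTION_STRING or from
	// EVENTHUB_NAMESPACE / EVENTHUB_KEY_NAME / EVENTHUB_KEY_VALUE.
	provider, err := sas.NewTokenProvider(sas.TokenProviderWithEnvironmentVars())
	if err != nil {
		log.Fatal(err)
	}

	// CBS put-token exchanges happen over an anonymous SASL connection to the namespace.
	client, err := amqp.Dial("amqps://mynamespace.servicebus.windows.net", amqp.ConnSASLAnonymous())
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Put a token to the $cbs endpoint for the entity (hub) we want to use.
	audience := "amqps://mynamespace.servicebus.windows.net/myhub"
	if err := cbs.NegotiateClaim(ctx, audience, client, provider); err != nil {
		log.Fatal(err)
	}
}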