Fixing race condition when closing the RPC link (#55)

Fixing a potential race condition in rpc.go - if the client was closed while we were about to send we could end up trying to access a nil map.

(saw in tests in service bus)
This commit is contained in:
Richard Park 2021-09-01 13:30:37 -07:00 коммит произвёл GitHub
Родитель 65ee81bc7d
Коммит 2a1e59e76f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 68 добавлений и 5 удалений

Просмотреть файл

@ -1,5 +1,10 @@
# Change Log
## `v3.1.2`
- Fixing a potential race condition when an RPC link is shut down while still sending requests
or handling responses.
[PR#55](https://github.com/Azure/azure-amqp-common-go/pull/55)
## `v3.1.1`
- Change `Link` so it can handle parallel requests.
[PR#52](https://github.com/Azure/azure-amqp-common-go/pull/52)

Просмотреть файл

@ -243,9 +243,8 @@ func (l *Link) startResponseRouter() {
continue
}
ch := l.deleteFromMap(autogenMessageId)
ch := l.deleteChannelFromMap(autogenMessageId)
// there's no legitimate case where this should be nil - purely defensive.
if ch != nil {
ch <- rpcResponse{message: res, err: err}
}
@ -286,10 +285,14 @@ func (l *Link) RPC(ctx context.Context, msg *amqp.Message) (*Response, error) {
responseCh := l.addChannelToMap(messageID)
if responseCh == nil {
return nil, amqp.ErrLinkClosed
}
err = l.sender.Send(ctx, msg)
if err != nil {
l.deleteFromMap(messageID)
l.deleteChannelFromMap(messageID)
tab.For(ctx).Error(err)
return nil, err
}
@ -298,7 +301,7 @@ func (l *Link) RPC(ctx context.Context, msg *amqp.Message) (*Response, error) {
select {
case <-ctx.Done():
l.deleteFromMap(messageID)
l.deleteChannelFromMap(messageID)
res, err = nil, ctx.Err()
case resp := <-responseCh:
// this will get triggered by the loop in 'startReceiverRouter' when it receives
@ -408,20 +411,36 @@ func (l *Link) closeSession(ctx context.Context) error {
return nil
}
// addChannelToMap adds a channel which will be used by the response router to
// notify when there is a response to the request.
// If l.responseMap is nil (for instance, via broadcastError) this function will
// return nil.
func (l *Link) addChannelToMap(messageID string) chan rpcResponse {
l.responseMu.Lock()
defer l.responseMu.Unlock()
if l.responseMap == nil {
return nil
}
responseCh := make(chan rpcResponse, 1)
l.responseMap[messageID] = responseCh
return responseCh
}
func (l *Link) deleteFromMap(messageID string) chan rpcResponse {
// deleteChannelFromMap removes the message from our internal map and returns
// a channel that the corresponding RPC() call is waiting on.
// If l.responseMap is nil (for instance, via broadcastError) this function will
// return nil.
func (l *Link) deleteChannelFromMap(messageID string) chan rpcResponse {
l.responseMu.Lock()
defer l.responseMu.Unlock()
if l.responseMap == nil {
return nil
}
ch := l.responseMap[messageID]
delete(l.responseMap, messageID)

Просмотреть файл

@ -241,6 +241,45 @@ func TestRPCFailedSend(t *testing.T) {
require.EqualValues(t, fakeUUID.String(), sender.Sent[0].Properties.MessageID, "Sent message contains a uniquely generated ID")
}
func TestRPCNilMessageMap(t *testing.T) {
fakeSender := &fakeSender{}
fakeReceiver := &fakeReceiver{
Responses: []rpcResponse{
// this should let us see what deleteChannelFromMap does
{amqpMessageWithCorrelationId("hello"), nil},
{nil, amqp.ErrLinkClosed},
},
}
link := &Link{
sender: fakeSender,
receiver: fakeReceiver,
// responseMap is nil if the broadcastError() function is called. Since this can be
// at any time our individual map functions need to handle the map not being
// there.
responseMap: nil,
startResponseRouterOnce: &sync.Once{},
uuidNewV4: uuid.NewV4,
}
// sanity check - all the map/channel functions are returning nil
require.Nil(t, link.addChannelToMap("hello"))
require.Nil(t, link.deleteChannelFromMap("hello"))
link.startResponseRouter()
require.Empty(t, fakeReceiver.Responses, "All responses are used")
// we're not testing the responseRouter for this second part, so just short-circuit
// the running.
link.startResponseRouterOnce.Do(func() {})
// now check that sending can handle it.
resp, err := link.RPC(context.Background(), &amqp.Message{})
require.Error(t, err, amqp.ErrLinkClosed.Error())
require.Nil(t, resp)
}
func amqpMessageWithCorrelationId(id string) *amqp.Message {
return &amqp.Message{
Data: [][]byte{[]byte(fmt.Sprintf("ID was %s", id))},