From 0d3cfff99f4428b308d63d9386fec9cd86475faa Mon Sep 17 00:00:00 2001 From: Adam Langley Date: Thu, 26 Apr 2012 12:05:35 -0400 Subject: [PATCH] ssh: fix deadlock The code was taking locks in the wrong order. Fixes golang/go#3570. R=fullung CC=golang-dev https://golang.org/cl/6123058 --- ssh/channel.go | 42 ++++++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/ssh/channel.go b/ssh/channel.go index 4f8050a6..415d8d89 100644 --- a/ssh/channel.go +++ b/ssh/channel.go @@ -229,16 +229,35 @@ func (edc extendedDataChannel) Write(data []byte) (n int, err error) { } func (c *channel) Read(data []byte) (n int, err error) { + n, err, windowAdjustment := c.read(data) + + if windowAdjustment > 0 { + packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{ + PeersId: c.theirId, + AdditionalBytes: windowAdjustment, + }) + c.serverConn.lock.Lock() + err = c.serverConn.writePacket(packet) + c.serverConn.lock.Unlock() + if err != nil { + return + } + } + + return +} + +func (c *channel) read(data []byte) (n int, err error, windowAdjustment uint32) { c.lock.Lock() defer c.lock.Unlock() if c.err != nil { - return 0, c.err + return 0, c.err, 0 } for { if c.theySentEOF || c.theyClosed || c.dead { - return 0, io.EOF + return 0, io.EOF, 0 } if len(c.pendingRequests) > 0 { @@ -251,7 +270,7 @@ func (c *channel) Read(data []byte) (n int, err error) { copy(c.pendingRequests, oldPendingRequests[1:]) } - return 0, req + return 0, req, 0 } if c.length > 0 { @@ -263,20 +282,11 @@ func (c *channel) Read(data []byte) (n int, err error) { c.head = 0 } - windowAdjustment := uint32(len(c.pendingData)-c.length) - c.myWindow - if windowAdjustment >= uint32(len(c.pendingData)/2) { - packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{ - PeersId: c.theirId, - AdditionalBytes: windowAdjustment, - }) - c.serverConn.lock.Lock() - err = c.serverConn.writePacket(packet) - c.serverConn.lock.Unlock() - if err != nil { - return - } - c.myWindow += windowAdjustment + windowAdjustment = uint32(len(c.pendingData)-c.length) - c.myWindow + if windowAdjustment < uint32(len(c.pendingData)/2) { + windowAdjustment = 0 } + c.myWindow += windowAdjustment return }