ssh: handle bad servers better.
This change prevents bad servers from crashing a client by sending an invalid channel ID. It also makes the client disconnect in more cases of invalid messages from a server and cleans up the client channels in the event of a disconnect. R=dave CC=golang-dev https://golang.org/cl/6099050
This commit is contained in:
Родитель
58afe880f1
Коммит
bcdd6a2fd3
|
@ -184,8 +184,16 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
|
|||
// mainLoop reads incoming messages and routes channel messages
|
||||
// to their respective ClientChans.
|
||||
func (c *ClientConn) mainLoop() {
|
||||
// TODO(dfc) signal the underlying close to all channels
|
||||
defer c.Close()
|
||||
defer func() {
|
||||
// We don't check, for example, that the channel IDs from the
|
||||
// server are valid before using them. Thus a bad server can
|
||||
// cause us to panic, but we don't want to crash the program.
|
||||
recover()
|
||||
|
||||
c.Close()
|
||||
c.closeAll()
|
||||
}()
|
||||
|
||||
for {
|
||||
packet, err := c.readPacket()
|
||||
if err != nil {
|
||||
|
@ -199,28 +207,34 @@ func (c *ClientConn) mainLoop() {
|
|||
case msgChannelData:
|
||||
if len(packet) < 9 {
|
||||
// malformed data packet
|
||||
break
|
||||
return
|
||||
}
|
||||
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
|
||||
if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
|
||||
packet = packet[9:]
|
||||
c.getChan(peersId).stdout.handleData(packet[:length])
|
||||
length := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
|
||||
packet = packet[9:]
|
||||
|
||||
if length != uint32(len(packet)) {
|
||||
return
|
||||
}
|
||||
c.getChan(peersId).stdout.handleData(packet)
|
||||
case msgChannelExtendedData:
|
||||
if len(packet) < 13 {
|
||||
// malformed data packet
|
||||
break
|
||||
return
|
||||
}
|
||||
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
|
||||
datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
|
||||
if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
|
||||
packet = packet[13:]
|
||||
// RFC 4254 5.2 defines data_type_code 1 to be data destined
|
||||
// for stderr on interactive sessions. Other data types are
|
||||
// silently discarded.
|
||||
if datatype == 1 {
|
||||
c.getChan(peersId).stderr.handleData(packet[:length])
|
||||
}
|
||||
length := uint32(packet[9])<<24 | uint32(packet[10])<<16 | uint32(packet[11])<<8 | uint32(packet[12])
|
||||
packet = packet[13:]
|
||||
|
||||
if length != uint32(len(packet)) {
|
||||
return
|
||||
}
|
||||
// RFC 4254 5.2 defines data_type_code 1 to be data destined
|
||||
// for stderr on interactive sessions. Other data types are
|
||||
// silently discarded.
|
||||
if datatype == 1 {
|
||||
c.getChan(peersId).stderr.handleData(packet)
|
||||
}
|
||||
default:
|
||||
switch msg := decode(packet).(type) {
|
||||
|
@ -256,10 +270,10 @@ func (c *ClientConn) mainLoop() {
|
|||
case *windowAdjustMsg:
|
||||
if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) {
|
||||
// invalid window update
|
||||
break
|
||||
return
|
||||
}
|
||||
case *disconnectMsg:
|
||||
break
|
||||
return
|
||||
default:
|
||||
fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
|
||||
}
|
||||
|
@ -408,6 +422,9 @@ func (c *chanlist) newChan(t *transport) *clientChan {
|
|||
func (c *chanlist) getChan(id uint32) *clientChan {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if id >= uint32(len(c.chans)) {
|
||||
return nil
|
||||
}
|
||||
return c.chans[int(id)]
|
||||
}
|
||||
|
||||
|
@ -417,6 +434,22 @@ func (c *chanlist) remove(id uint32) {
|
|||
c.chans[int(id)] = nil
|
||||
}
|
||||
|
||||
func (c *chanlist) closeAll() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
for _, ch := range c.chans {
|
||||
if ch == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ch.theyClosed = true
|
||||
ch.stdout.eof()
|
||||
ch.stderr.eof()
|
||||
close(ch.msg)
|
||||
}
|
||||
}
|
||||
|
||||
// A chanWriter represents the stdin of a remote process.
|
||||
type chanWriter struct {
|
||||
win *window
|
||||
|
|
|
@ -275,6 +275,20 @@ func TestExitWithoutStatusOrSignal(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestInvalidServerMessage(t *testing.T) {
|
||||
conn := dial(sendInvalidRecord, t)
|
||||
defer conn.Close()
|
||||
session, err := conn.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to request new session: %s", err)
|
||||
}
|
||||
// Make sure that we closed all the clientChans when the connection
|
||||
// failed.
|
||||
session.wait()
|
||||
|
||||
defer session.Close()
|
||||
}
|
||||
|
||||
type exitStatusMsg struct {
|
||||
PeersId uint32
|
||||
Request string
|
||||
|
@ -373,3 +387,14 @@ func sendSignal(signal string, ch *channel) {
|
|||
}
|
||||
ch.serverConn.writePacket(marshal(msgChannelRequest, sig))
|
||||
}
|
||||
|
||||
func sendInvalidRecord(ch *channel) {
|
||||
defer ch.Close()
|
||||
packet := make([]byte, 1+4+4+1)
|
||||
packet[0] = msgChannelData
|
||||
marshalUint32(packet[1:], 29348723 /* invalid channel id */)
|
||||
marshalUint32(packet[5:], 1)
|
||||
packet[9] = 42
|
||||
|
||||
ch.serverConn.writePacket(packet)
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче