ssh: handle bad servers better.

This change prevents bad servers from crashing a client by sending an
invalid channel ID. It also makes the client disconnect in more cases
of invalid messages from a server and cleans up the client channels
in the event of a disconnect.

R=dave
CC=golang-dev
https://golang.org/cl/6099050
This commit is contained in:
Adam Langley 2012-04-24 13:46:22 -04:00
Родитель 58afe880f1
Коммит bcdd6a2fd3
2 изменённых файлов: 75 добавлений и 17 удалений

Просмотреть файл

@ -184,8 +184,16 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
// mainLoop reads incoming messages and routes channel messages
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
// TODO(dfc) signal the underlying close to all channels
defer c.Close()
defer func() {
// We don't check, for example, that the channel IDs from the
// server are valid before using them. Thus a bad server can
// cause us to panic, but we don't want to crash the program.
recover()
c.Close()
c.closeAll()
}()
for {
packet, err := c.readPacket()
if err != nil {
@ -199,28 +207,34 @@ func (c *ClientConn) mainLoop() {
case msgChannelData:
if len(packet) < 9 {
// malformed data packet
break
return
}
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
packet = packet[9:]
c.getChan(peersId).stdout.handleData(packet[:length])
length := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
packet = packet[9:]
if length != uint32(len(packet)) {
return
}
c.getChan(peersId).stdout.handleData(packet)
case msgChannelExtendedData:
if len(packet) < 13 {
// malformed data packet
break
return
}
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
packet = packet[13:]
// RFC 4254 5.2 defines data_type_code 1 to be data destined
// for stderr on interactive sessions. Other data types are
// silently discarded.
if datatype == 1 {
c.getChan(peersId).stderr.handleData(packet[:length])
}
length := uint32(packet[9])<<24 | uint32(packet[10])<<16 | uint32(packet[11])<<8 | uint32(packet[12])
packet = packet[13:]
if length != uint32(len(packet)) {
return
}
// RFC 4254 5.2 defines data_type_code 1 to be data destined
// for stderr on interactive sessions. Other data types are
// silently discarded.
if datatype == 1 {
c.getChan(peersId).stderr.handleData(packet)
}
default:
switch msg := decode(packet).(type) {
@ -256,10 +270,10 @@ func (c *ClientConn) mainLoop() {
case *windowAdjustMsg:
if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) {
// invalid window update
break
return
}
case *disconnectMsg:
break
return
default:
fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
}
@ -408,6 +422,9 @@ func (c *chanlist) newChan(t *transport) *clientChan {
func (c *chanlist) getChan(id uint32) *clientChan {
c.Lock()
defer c.Unlock()
if id >= uint32(len(c.chans)) {
return nil
}
return c.chans[int(id)]
}
@ -417,6 +434,22 @@ func (c *chanlist) remove(id uint32) {
c.chans[int(id)] = nil
}
func (c *chanlist) closeAll() {
c.Lock()
defer c.Unlock()
for _, ch := range c.chans {
if ch == nil {
continue
}
ch.theyClosed = true
ch.stdout.eof()
ch.stderr.eof()
close(ch.msg)
}
}
// A chanWriter represents the stdin of a remote process.
type chanWriter struct {
win *window

Просмотреть файл

@ -275,6 +275,20 @@ func TestExitWithoutStatusOrSignal(t *testing.T) {
}
}
func TestInvalidServerMessage(t *testing.T) {
conn := dial(sendInvalidRecord, t)
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("Unable to request new session: %s", err)
}
// Make sure that we closed all the clientChans when the connection
// failed.
session.wait()
defer session.Close()
}
type exitStatusMsg struct {
PeersId uint32
Request string
@ -373,3 +387,14 @@ func sendSignal(signal string, ch *channel) {
}
ch.serverConn.writePacket(marshal(msgChannelRequest, sig))
}
func sendInvalidRecord(ch *channel) {
defer ch.Close()
packet := make([]byte, 1+4+4+1)
packet[0] = msgChannelData
marshalUint32(packet[1:], 29348723 /* invalid channel id */)
marshalUint32(packet[5:], 1)
packet[9] = 42
ch.serverConn.writePacket(packet)
}