diff --git a/ssh/messages.go b/ssh/messages.go index e5d76b01..94c3ea03 100644 --- a/ssh/messages.go +++ b/ssh/messages.go @@ -227,17 +227,20 @@ type userAuthPubKeyOkMsg struct { PubKey string } -// unmarshal parses the SSH wire data in packet into out using reflection. -// expectedType is the expected SSH message type. It either returns nil on -// success, or a ParseError or UnexpectedMessageError on error. +// unmarshal parses the SSH wire data in packet into out using +// reflection. expectedType, if non-zero, is the SSH message type that +// the packet is expected to start with. unmarshal either returns nil +// on success, or a ParseError or UnexpectedMessageError on error. func unmarshal(out interface{}, packet []byte, expectedType uint8) error { if len(packet) == 0 { return ParseError{expectedType} } - if packet[0] != expectedType { - return UnexpectedMessageError{expectedType, packet[0]} + if expectedType > 0 { + if packet[0] != expectedType { + return UnexpectedMessageError{expectedType, packet[0]} + } + packet = packet[1:] } - packet = packet[1:] v := reflect.ValueOf(out).Elem() structType := v.Type() @@ -319,10 +322,13 @@ func unmarshal(out interface{}, packet []byte, expectedType uint8) error { return nil } -// marshal serializes the message in msg, using the given message type. +// marshal serializes the message in msg. The given message type is +// prepended if it is non-zero. func marshal(msgType uint8, msg interface{}) []byte { - out := make([]byte, 1, 64) - out[0] = msgType + out := make([]byte, 0, 64) + if msgType > 0 { + out = append(out, msgType) + } v := reflect.ValueOf(msg) for i, n := 0, v.NumField(); i < n; i++ { diff --git a/ssh/messages_test.go b/ssh/messages_test.go index fb944041..fd86b62f 100644 --- a/ssh/messages_test.go +++ b/ssh/messages_test.go @@ -78,6 +78,38 @@ func TestMarshalUnmarshal(t *testing.T) { } } +func TestBareMarshalUnmarshal(t *testing.T) { + type S struct { + I uint32 + S string + B bool + } + + s := S{42, "hello", true} + packet := marshal(0, s) + roundtrip := S{} + unmarshal(&roundtrip, packet, 0) + + if !reflect.DeepEqual(s, roundtrip) { + t.Errorf("got %#v, want %#v", roundtrip, s) + } +} + +func TestBareMarshal(t *testing.T) { + type S2 struct { + I uint32 + } + s := S2{42} + packet := marshal(0, s) + i, rest, ok := parseUint32(packet) + if len(rest) > 0 || !ok { + t.Errorf("parseInt(%q): parse error", packet) + } + if i != s.I { + t.Errorf("got %d, want %d", i, s.I) + } +} + func randomBytes(out []byte, rand *rand.Rand) { for i := 0; i < len(out); i++ { out[i] = byte(rand.Int31())