From f9ac13d4698ab7feb6e4f86165a2d9e74430e70b Mon Sep 17 00:00:00 2001 From: Cesar Ghali Date: Thu, 16 Apr 2020 15:09:15 -0700 Subject: [PATCH] credentials/alts: Properly release server InBytes buffer after the handshake is complete. (#3529) --- credentials/alts/internal/conn/record.go | 10 ++- credentials/alts/internal/conn/record_test.go | 65 +++++++++++++++---- 2 files changed, 60 insertions(+), 15 deletions(-) diff --git a/credentials/alts/internal/conn/record.go b/credentials/alts/internal/conn/record.go index fd5a53d9..8a872c3c 100644 --- a/credentials/alts/internal/conn/record.go +++ b/credentials/alts/internal/conn/record.go @@ -111,6 +111,7 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot } overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead() payloadLengthLimit := altsRecordDefaultLength - overhead + var protectedBuf []byte if protected == nil { // We pre-allocate protected to be of size // 2*altsRecordDefaultLength-1 during initialization. We only @@ -120,16 +121,19 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot // altsRecordDefaultLength (bytes) data into protected at one // time. Therefore, 2*altsRecordDefaultLength-1 is large enough // to buffer data read from the network. - protected = make([]byte, 0, 2*altsRecordDefaultLength-1) + protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1) + } else { + protectedBuf = make([]byte, len(protected)) + copy(protectedBuf, protected) } altsConn := &conn{ Conn: c, crypto: crypto, payloadLengthLimit: payloadLengthLimit, - protected: protected, + protected: protectedBuf, writeBuf: make([]byte, altsWriteBufferInitialSize), - nextFrame: protected, + nextFrame: protectedBuf, overhead: overhead, } return altsConn, nil diff --git a/credentials/alts/internal/conn/record_test.go b/credentials/alts/internal/conn/record_test.go index af3bc9b6..59d4f41e 100644 --- a/credentials/alts/internal/conn/record_test.go +++ b/credentials/alts/internal/conn/record_test.go @@ -77,7 +77,7 @@ func (c *testConn) Close() error { return nil } -func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *conn { +func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string, protected []byte) *conn { key := []byte{ // 16 arbitrary bytes. 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} @@ -85,23 +85,23 @@ func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *co in: in, out: out, } - c, err := NewConn(&tc, side, np, key, nil) + c, err := NewConn(&tc, side, np, key, protected) if err != nil { panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err)) } return c.(*conn) } -func newConnPair(np string) (client, server *conn) { +func newConnPair(np string, clientProtected []byte, serverProtected []byte) (client, server *conn) { clientBuf := new(bytes.Buffer) serverBuf := new(bytes.Buffer) - clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np) - serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np) + clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np, clientProtected) + serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np, serverProtected) return clientConn, serverConn } func testPingPong(t *testing.T, np string) { - clientConn, serverConn := newConnPair(np) + clientConn, serverConn := newConnPair(np, nil, nil) clientMsg := []byte("Client Message") if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil { t.Fatalf("Client Write() = %v, %v; want %v, ", n, err, len(clientMsg)) @@ -134,7 +134,7 @@ func (s) TestPingPong(t *testing.T) { } func testSmallReadBuffer(t *testing.T, np string) { - clientConn, serverConn := newConnPair(np) + clientConn, serverConn := newConnPair(np, nil, nil) msg := []byte("Very Important Message") if n, err := clientConn.Write(msg); err != nil { t.Fatalf("Write() = %v, %v; want %v, ", n, err, len(msg)) @@ -161,7 +161,7 @@ func (s) TestSmallReadBuffer(t *testing.T) { } func testLargeMsg(t *testing.T, np string) { - clientConn, serverConn := newConnPair(np) + clientConn, serverConn := newConnPair(np, nil, nil) // msgLen is such that the length in the framing is larger than the // default size of one frame. msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 @@ -193,7 +193,7 @@ func testIncorrectMsgType(t *testing.T, np string) { binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType) in := bytes.NewBuffer(framedMsg) - c := newTestALTSRecordConn(in, nil, core.ClientSide, np) + c := newTestALTSRecordConn(in, nil, core.ClientSide, np, nil) b := make([]byte, 1) if n, err := c.Read(b); n != 0 || err == nil { t.Fatalf("Read() = , want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType)) @@ -208,8 +208,8 @@ func (s) TestIncorrectMsgType(t *testing.T) { func testFrameTooLarge(t *testing.T, np string) { buf := new(bytes.Buffer) - clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np) - serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np) + clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np, nil) + serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np, nil) // payloadLen is such that the length in the framing is larger than // allowed in one frame. payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1 @@ -242,7 +242,7 @@ func (s) TestFrameTooLarge(t *testing.T) { func testWriteLargeData(t *testing.T, np string) { // Test sending and receiving messages larger than the maximum write // buffer size. - clientConn, serverConn := newConnPair(np) + clientConn, serverConn := newConnPair(np, nil, nil) // Message size is intentionally chosen to not be multiple of // payloadLengthLimtit. msgSize := altsWriteBufferMaxSize + (100 * 1024) @@ -281,3 +281,44 @@ func (s) TestWriteLargeData(t *testing.T) { testWriteLargeData(t, np) } } + +func testProtectedBuffer(t *testing.T, np string) { + key := []byte{ + // 16 arbitrary bytes. + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49} + + // Encrypt a message to be passed to NewConn as a client-side protected + // buffer. + newCrypto := protocols[np] + if newCrypto == nil { + t.Fatalf("Unknown next protocol %q", np) + } + crypto, err := newCrypto(core.ClientSide, key) + if err != nil { + t.Fatalf("Failed to create a crypter for protocol %q: %v", np, err) + } + msg := []byte("Client Protected Message") + encryptedMsg, err := crypto.Encrypt(nil, msg) + if err != nil { + t.Fatalf("Failed to encrypt the client protected message: %v", err) + } + protectedMsg := make([]byte, 8) // 8 bytes = 4 length + 4 type + binary.LittleEndian.PutUint32(protectedMsg, uint32(len(encryptedMsg))+4) // 4 bytes for the type + binary.LittleEndian.PutUint32(protectedMsg[4:], altsRecordMsgType) + protectedMsg = append(protectedMsg, encryptedMsg...) + + _, serverConn := newConnPair(np, nil, protectedMsg) + rcvClientMsg := make([]byte, len(msg)) + if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil { + t.Fatalf("Server Read() = %v, %v; want %v, ", n, err, len(rcvClientMsg)) + } + if !reflect.DeepEqual(msg, rcvClientMsg) { + t.Fatalf("Client protected/Server Read() = %v, want %v", rcvClientMsg, msg) + } +} + +func (s) TestProtectedBuffer(t *testing.T) { + for _, np := range nextProtocols { + testProtectedBuffer(t, np) + } +}