credentials/alts: Properly release server InBytes buffer after the handshake is complete. (#3529)
This commit is contained in:
Родитель
759569bb9c
Коммит
f9ac13d469
|
@ -111,6 +111,7 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
|
|||
}
|
||||
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
|
||||
payloadLengthLimit := altsRecordDefaultLength - overhead
|
||||
var protectedBuf []byte
|
||||
if protected == nil {
|
||||
// We pre-allocate protected to be of size
|
||||
// 2*altsRecordDefaultLength-1 during initialization. We only
|
||||
|
@ -120,16 +121,19 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
|
|||
// altsRecordDefaultLength (bytes) data into protected at one
|
||||
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
|
||||
// to buffer data read from the network.
|
||||
protected = make([]byte, 0, 2*altsRecordDefaultLength-1)
|
||||
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
|
||||
} else {
|
||||
protectedBuf = make([]byte, len(protected))
|
||||
copy(protectedBuf, protected)
|
||||
}
|
||||
|
||||
altsConn := &conn{
|
||||
Conn: c,
|
||||
crypto: crypto,
|
||||
payloadLengthLimit: payloadLengthLimit,
|
||||
protected: protected,
|
||||
protected: protectedBuf,
|
||||
writeBuf: make([]byte, altsWriteBufferInitialSize),
|
||||
nextFrame: protected,
|
||||
nextFrame: protectedBuf,
|
||||
overhead: overhead,
|
||||
}
|
||||
return altsConn, nil
|
||||
|
|
|
@ -77,7 +77,7 @@ func (c *testConn) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *conn {
|
||||
func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string, protected []byte) *conn {
|
||||
key := []byte{
|
||||
// 16 arbitrary bytes.
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
|
||||
|
@ -85,23 +85,23 @@ func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *co
|
|||
in: in,
|
||||
out: out,
|
||||
}
|
||||
c, err := NewConn(&tc, side, np, key, nil)
|
||||
c, err := NewConn(&tc, side, np, key, protected)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
|
||||
}
|
||||
return c.(*conn)
|
||||
}
|
||||
|
||||
func newConnPair(np string) (client, server *conn) {
|
||||
func newConnPair(np string, clientProtected []byte, serverProtected []byte) (client, server *conn) {
|
||||
clientBuf := new(bytes.Buffer)
|
||||
serverBuf := new(bytes.Buffer)
|
||||
clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np)
|
||||
serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np)
|
||||
clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np, clientProtected)
|
||||
serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np, serverProtected)
|
||||
return clientConn, serverConn
|
||||
}
|
||||
|
||||
func testPingPong(t *testing.T, np string) {
|
||||
clientConn, serverConn := newConnPair(np)
|
||||
clientConn, serverConn := newConnPair(np, nil, nil)
|
||||
clientMsg := []byte("Client Message")
|
||||
if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
|
||||
t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
|
||||
|
@ -134,7 +134,7 @@ func (s) TestPingPong(t *testing.T) {
|
|||
}
|
||||
|
||||
func testSmallReadBuffer(t *testing.T, np string) {
|
||||
clientConn, serverConn := newConnPair(np)
|
||||
clientConn, serverConn := newConnPair(np, nil, nil)
|
||||
msg := []byte("Very Important Message")
|
||||
if n, err := clientConn.Write(msg); err != nil {
|
||||
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
|
||||
|
@ -161,7 +161,7 @@ func (s) TestSmallReadBuffer(t *testing.T) {
|
|||
}
|
||||
|
||||
func testLargeMsg(t *testing.T, np string) {
|
||||
clientConn, serverConn := newConnPair(np)
|
||||
clientConn, serverConn := newConnPair(np, nil, nil)
|
||||
// msgLen is such that the length in the framing is larger than the
|
||||
// default size of one frame.
|
||||
msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
|
||||
|
@ -193,7 +193,7 @@ func testIncorrectMsgType(t *testing.T, np string) {
|
|||
binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType)
|
||||
|
||||
in := bytes.NewBuffer(framedMsg)
|
||||
c := newTestALTSRecordConn(in, nil, core.ClientSide, np)
|
||||
c := newTestALTSRecordConn(in, nil, core.ClientSide, np, nil)
|
||||
b := make([]byte, 1)
|
||||
if n, err := c.Read(b); n != 0 || err == nil {
|
||||
t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType))
|
||||
|
@ -208,8 +208,8 @@ func (s) TestIncorrectMsgType(t *testing.T) {
|
|||
|
||||
func testFrameTooLarge(t *testing.T, np string) {
|
||||
buf := new(bytes.Buffer)
|
||||
clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np)
|
||||
serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np)
|
||||
clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np, nil)
|
||||
serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np, nil)
|
||||
// payloadLen is such that the length in the framing is larger than
|
||||
// allowed in one frame.
|
||||
payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
|
||||
|
@ -242,7 +242,7 @@ func (s) TestFrameTooLarge(t *testing.T) {
|
|||
func testWriteLargeData(t *testing.T, np string) {
|
||||
// Test sending and receiving messages larger than the maximum write
|
||||
// buffer size.
|
||||
clientConn, serverConn := newConnPair(np)
|
||||
clientConn, serverConn := newConnPair(np, nil, nil)
|
||||
// Message size is intentionally chosen to not be multiple of
|
||||
// payloadLengthLimtit.
|
||||
msgSize := altsWriteBufferMaxSize + (100 * 1024)
|
||||
|
@ -281,3 +281,44 @@ func (s) TestWriteLargeData(t *testing.T) {
|
|||
testWriteLargeData(t, np)
|
||||
}
|
||||
}
|
||||
|
||||
func testProtectedBuffer(t *testing.T, np string) {
|
||||
key := []byte{
|
||||
// 16 arbitrary bytes.
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
|
||||
|
||||
// Encrypt a message to be passed to NewConn as a client-side protected
|
||||
// buffer.
|
||||
newCrypto := protocols[np]
|
||||
if newCrypto == nil {
|
||||
t.Fatalf("Unknown next protocol %q", np)
|
||||
}
|
||||
crypto, err := newCrypto(core.ClientSide, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create a crypter for protocol %q: %v", np, err)
|
||||
}
|
||||
msg := []byte("Client Protected Message")
|
||||
encryptedMsg, err := crypto.Encrypt(nil, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt the client protected message: %v", err)
|
||||
}
|
||||
protectedMsg := make([]byte, 8) // 8 bytes = 4 length + 4 type
|
||||
binary.LittleEndian.PutUint32(protectedMsg, uint32(len(encryptedMsg))+4) // 4 bytes for the type
|
||||
binary.LittleEndian.PutUint32(protectedMsg[4:], altsRecordMsgType)
|
||||
protectedMsg = append(protectedMsg, encryptedMsg...)
|
||||
|
||||
_, serverConn := newConnPair(np, nil, protectedMsg)
|
||||
rcvClientMsg := make([]byte, len(msg))
|
||||
if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil {
|
||||
t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg))
|
||||
}
|
||||
if !reflect.DeepEqual(msg, rcvClientMsg) {
|
||||
t.Fatalf("Client protected/Server Read() = %v, want %v", rcvClientMsg, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestProtectedBuffer(t *testing.T) {
|
||||
for _, np := range nextProtocols {
|
||||
testProtectedBuffer(t, np)
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче