Each side of a QUIC connection chooses the connection IDs used by
its peer. In our case, we use 8-byte random IDs.

A connection has a list of connection IDs that it may receive
packets on, and a list that it may send packets to. Add a minimal
data structure for tracking these lists, and handling of the
connection IDs tracked across Initial and Handshake packets.

This does not yet handle post-handshake connection ID changes
made in NEW_CONNECTION_ID and RETIRE_CONNECTION_ID frames.

RFC 9000, Section 5.1.

For golang/go#58547

Change-Id: I3e059393cacafbcea04a1b4131c0c7dc28acad5e
Reviewed-on: https://go-review.googlesource.com/c/net/+/506675
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
Damien Neil 2022-10-13 12:09:20 -07:00
Родитель 304cc91b19
Коммит 57553cbff1
2 изменённых файлов: 256 добавлений и 0 удалений

147
internal/quic/conn_id.go Normal file
Просмотреть файл

@ -0,0 +1,147 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21
package quic
import (
"crypto/rand"
)
// connIDState is a conn's connection IDs.
type connIDState struct {
// The destination connection IDs of packets we receive are local.
// The destination connection IDs of packets we send are remote.
//
// Local IDs are usually issued by us, and remote IDs by the peer.
// The exception is the transient destination connection ID sent in
// a client's Initial packets, which is chosen by the client.
local []connID
remote []connID
}
// A connID is a connection ID and associated metadata.
type connID struct {
// cid is the connection ID itself.
cid []byte
// seq is the connection ID's sequence number:
// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-1
//
// For the transient destination ID in a client's Initial packet, this is -1.
seq int64
}
func (s *connIDState) initClient(newID newConnIDFunc) error {
// Client chooses its initial connection ID, and sends it
// in the Source Connection ID field of the first Initial packet.
locid, err := newID()
if err != nil {
return err
}
s.local = append(s.local, connID{
seq: 0,
cid: locid,
})
// Client chooses an initial, transient connection ID for the server,
// and sends it in the Destination Connection ID field of the first Initial packet.
remid, err := newID()
if err != nil {
return err
}
s.remote = append(s.remote, connID{
seq: -1,
cid: remid,
})
return nil
}
func (s *connIDState) initServer(newID newConnIDFunc, dstConnID []byte) error {
// Client-chosen, transient connection ID received in the first Initial packet.
// The server will not use this as the Source Connection ID of packets it sends,
// but remembers it because it may receive packets sent to this destination.
s.local = append(s.local, connID{
seq: -1,
cid: cloneBytes(dstConnID),
})
// Server chooses a connection ID, and sends it in the Source Connection ID of
// the response to the clent.
locid, err := newID()
if err != nil {
return err
}
s.local = append(s.local, connID{
seq: 0,
cid: locid,
})
return nil
}
// srcConnID is the Source Connection ID to use in a sent packet.
func (s *connIDState) srcConnID() []byte {
if s.local[0].seq == -1 && len(s.local) > 1 {
// Don't use the transient connection ID if another is available.
return s.local[1].cid
}
return s.local[0].cid
}
// dstConnID is the Destination Connection ID to use in a sent packet.
func (s *connIDState) dstConnID() []byte {
return s.remote[0].cid
}
// handlePacket updates the connection ID state during the handshake
// (Initial and Handshake packets).
func (s *connIDState) handlePacket(side connSide, ptype packetType, srcConnID []byte) {
switch {
case ptype == packetTypeInitial && side == clientSide:
if len(s.remote) == 1 && s.remote[0].seq == -1 {
// We're a client connection processing the first Initial packet
// from the server. Replace the transient remote connection ID
// with the Source Connection ID from the packet.
s.remote[0] = connID{
seq: 0,
cid: cloneBytes(srcConnID),
}
}
case ptype == packetTypeInitial && side == serverSide:
if len(s.remote) == 0 {
// We're a server connection processing the first Initial packet
// from the client. Set the client's connection ID.
s.remote = append(s.remote, connID{
seq: 0,
cid: cloneBytes(srcConnID),
})
}
case ptype == packetTypeHandshake && side == serverSide:
if len(s.local) > 0 && s.local[0].seq == -1 {
// We're a server connection processing the first Handshake packet from
// the client. Discard the transient, client-chosen connection ID used
// for Initial packets; the client will never send it again.
s.local = append(s.local[:0], s.local[1:]...)
}
}
}
func cloneBytes(b []byte) []byte {
n := make([]byte, len(b))
copy(n, b)
return n
}
type newConnIDFunc func() ([]byte, error)
func newRandomConnID() ([]byte, error) {
// It is not necessary for connection IDs to be cryptographically secure,
// but it doesn't hurt.
id := make([]byte, connIDLen)
if _, err := rand.Read(id); err != nil {
return nil, err
}
return id, nil
}

Просмотреть файл

@ -0,0 +1,109 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.21
package quic
import (
"fmt"
"reflect"
"testing"
)
func TestConnIDClientHandshake(t *testing.T) {
// On initialization, the client chooses local and remote IDs.
//
// The order in which we allocate the two isn't actually important,
// but test is a lot simpler if we assume.
var s connIDState
s.initClient(newConnIDSequence())
if got, want := string(s.srcConnID()), "local-1"; got != want {
t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
}
if got, want := string(s.dstConnID()), "local-2"; got != want {
t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
}
// The server's first Initial packet provides the client with a
// non-transient remote connection ID.
s.handlePacket(clientSide, packetTypeInitial, []byte("remote-1"))
if got, want := string(s.dstConnID()), "remote-1"; got != want {
t.Errorf("after receiving Initial: dstConnID = %q, want %q", got, want)
}
wantLocal := []connID{{
cid: []byte("local-1"),
seq: 0,
}}
if !reflect.DeepEqual(s.local, wantLocal) {
t.Errorf("local ids: %v, want %v", s.local, wantLocal)
}
wantRemote := []connID{{
cid: []byte("remote-1"),
seq: 0,
}}
if !reflect.DeepEqual(s.remote, wantRemote) {
t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
}
}
func TestConnIDServerHandshake(t *testing.T) {
// On initialization, the server is provided with the client-chosen
// transient connection ID, and allocates an ID of its own.
// The Initial packet sets the remote connection ID.
var s connIDState
s.initServer(newConnIDSequence(), []byte("transient"))
s.handlePacket(serverSide, packetTypeInitial, []byte("remote-1"))
if got, want := string(s.srcConnID()), "local-1"; got != want {
t.Errorf("after initClient: srcConnID = %q, want %q", got, want)
}
if got, want := string(s.dstConnID()), "remote-1"; got != want {
t.Errorf("after initClient: dstConnID = %q, want %q", got, want)
}
wantLocal := []connID{{
cid: []byte("transient"),
seq: -1,
}, {
cid: []byte("local-1"),
seq: 0,
}}
if !reflect.DeepEqual(s.local, wantLocal) {
t.Errorf("local ids: %v, want %v", s.local, wantLocal)
}
wantRemote := []connID{{
cid: []byte("remote-1"),
seq: 0,
}}
if !reflect.DeepEqual(s.remote, wantRemote) {
t.Errorf("remote ids: %v, want %v", s.remote, wantRemote)
}
// The client's first Handshake packet permits the server to discard the
// transient connection ID.
s.handlePacket(serverSide, packetTypeHandshake, []byte("remote-1"))
wantLocal = []connID{{
cid: []byte("local-1"),
seq: 0,
}}
if !reflect.DeepEqual(s.local, wantLocal) {
t.Errorf("after handshake local ids: %v, want %v", s.local, wantLocal)
}
}
func newConnIDSequence() newConnIDFunc {
var n uint64
return func() ([]byte, error) {
n++
return []byte(fmt.Sprintf("local-%v", n)), nil
}
}
func TestNewRandomConnID(t *testing.T) {
cid, err := newRandomConnID()
if len(cid) != connIDLen || err != nil {
t.Fatalf("newConnID() = %x, %v; want %v bytes", cid, connIDLen, err)
}
}