Skip to content

Commit

Permalink
quic: validate connection id transport parameters
Browse files Browse the repository at this point in the history
Validate the original_destination_connection_id and
initial_source_connection_id transport parameters.

RFC 9000, Section 7.3

For golang/go#58547

Change-Id: I8343fd53c5cc946f15d3410c632b3895205fd597
Reviewed-on: https://go-review.googlesource.com/c/net/+/530036
Reviewed-by: Jonathan Amsterdam <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
  • Loading branch information
neild committed Oct 3, 2023
1 parent a600b35 commit 21814e7
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 7 deletions.
8 changes: 7 additions & 1 deletion internal/quic/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.
// non-blocking operation.
c.msgc = make(chan any, 1)

var originalDstConnID []byte
if c.side == clientSide {
if err := c.connIDState.initClient(c); err != nil {
return nil, err
Expand All @@ -95,6 +96,7 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.
if err := c.connIDState.initServer(c, initialConnID); err != nil {
return nil, err
}
originalDstConnID = initialConnID
}

// The smallest allowed maximum QUIC datagram size is 1200 bytes.
Expand All @@ -105,9 +107,10 @@ func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.
c.streamsInit()
c.lifetimeInit()

// TODO: initial_source_connection_id, retry_source_connection_id
// TODO: retry_source_connection_id
if err := c.startTLS(now, initialConnID, transportParameters{
initialSrcConnID: c.connIDState.srcConnID(),
originalDstConnID: originalDstConnID,
ackDelayExponent: ackDelayExponent,
maxUDPPayloadSize: maxUDPPayloadSize,
maxAckDelay: maxAckDelay,
Expand Down Expand Up @@ -171,6 +174,9 @@ func (c *Conn) discardKeys(now time.Time, space numberSpace) {

// receiveTransportParameters applies transport parameters sent by the peer.
func (c *Conn) receiveTransportParameters(p transportParameters) error {
if err := c.connIDState.validateTransportParameters(c.side, p); err != nil {
return err
}
c.streams.outflow.setMaxData(p.initialMaxData)
c.streams.localLimit[bidiStream].setMax(p.initialMaxStreamsBidi)
c.streams.localLimit[uniStream].setMax(p.initialMaxStreamsUni)
Expand Down
44 changes: 40 additions & 4 deletions internal/quic/conn_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,39 @@ func (s *connIDState) issueLocalIDs(c *Conn) error {
return nil
}

// validateTransportParameters verifies the original_destination_connection_id and
// initial_source_connection_id transport parameters match the expected values.
func (s *connIDState) validateTransportParameters(side connSide, p transportParameters) error {
// TODO: Consider returning more detailed errors, for debugging.
switch side {
case clientSide:
// Verify original_destination_connection_id matches
// the transient remote connection ID we chose.
if len(s.remote) == 0 || s.remote[0].seq != -1 {
return localTransportError(errInternal)
}
if !bytes.Equal(s.remote[0].cid, p.originalDstConnID) {
return localTransportError(errTransportParameter)
}
// Remove the transient remote connection ID.
// We have no further need for it.
s.remote = append(s.remote[:0], s.remote[1:]...)
case serverSide:
if p.originalDstConnID != nil {
// Clients do not send original_destination_connection_id.
return localTransportError(errTransportParameter)
}
}
// Verify initial_source_connection_id matches the first remote connection ID.
if len(s.remote) == 0 || s.remote[0].seq != 0 {
return localTransportError(errInternal)
}
if !bytes.Equal(p.initialSrcConnID, s.remote[0].cid) {
return localTransportError(errTransportParameter)
}
return nil
}

// handlePacket updates the connection ID state during the handshake
// (Initial and Handshake packets).
func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) {
Expand All @@ -170,10 +203,13 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte)
// We're a client connection processing the first Initial packet
// from the server. Replace the transient remote connection ID
// with the Source Connection ID from the packet.
s.remote[0] = connID{
// Leave the transient ID the list for now, since we'll need it when
// processing the transport parameters.
s.remote[0].retired = true
s.remote = append(s.remote, connID{
seq: 0,
cid: cloneBytes(srcConnID),
}
})
}
case ptype == packetTypeInitial && c.side == serverSide:
if len(s.remote) == 0 {
Expand All @@ -185,7 +221,7 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte)
})
}
case ptype == packetTypeHandshake && c.side == serverSide:
if len(s.local) > 0 && s.local[0].seq == -1 {
if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired {
// We're a server connection processing the first Handshake packet from
// the client. Discard the transient, client-chosen connection ID used
// for Initial packets; the client will never send it again.
Expand Down Expand Up @@ -213,7 +249,7 @@ func (s *connIDState) handleNewConnID(seq, retire int64, cid []byte, resetToken
active := 0
for i := range s.remote {
rcid := &s.remote[i]
if !rcid.retired && rcid.seq < s.retireRemotePriorTo {
if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo {
s.retireRemote(rcid)
}
if !rcid.retired {
Expand Down
38 changes: 36 additions & 2 deletions internal/quic/conn_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ func TestConnIDClientHandshake(t *testing.T) {
t.Errorf("local ids: %v, want %v", fmtConnIDList(got), fmtConnIDList(wantLocal))
}
wantRemote := []connID{{
cid: testLocalConnID(-1),
seq: -1,
}, {
cid: testPeerConnID(0),
seq: 0,
}}
Expand Down Expand Up @@ -261,10 +264,12 @@ func TestConnIDPeerRetiresConnID(t *testing.T) {
}

func TestConnIDPeerWithZeroLengthConnIDSendsNewConnectionID(t *testing.T) {
// An endpoint that selects a zero-length connection ID during the handshake
// "An endpoint that selects a zero-length connection ID during the handshake
// cannot issue a new connection ID."
// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-8
tc := newTestConn(t, clientSide)
tc := newTestConn(t, clientSide, func(p *transportParameters) {
p.initialSrcConnID = []byte{}
})
tc.peerConnID = []byte{}
tc.ignoreFrame(frameTypeAck)
tc.uncheckedHandshake()
Expand Down Expand Up @@ -536,6 +541,7 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) {
// Peer gives us more conn ids than our advertised limit,
// including a conn id in the preferred address transport parameter.
tc := newTestConn(t, serverSide, func(p *transportParameters) {
p.initialSrcConnID = []byte{}
p.preferredAddrV4 = netip.MustParseAddrPort("0.0.0.0:0")
p.preferredAddrV6 = netip.MustParseAddrPort("[::0]:0")
p.preferredAddrConnID = testPeerConnID(1)
Expand All @@ -552,3 +558,31 @@ func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) {
code: errProtocolViolation,
})
}

func TestConnIDInitialSrcConnIDMismatch(t *testing.T) {
// "Endpoints MUST validate that received [initial_source_connection_id]
// parameters match received connection ID values."
// https://www.rfc-editor.org/rfc/rfc9000#section-7.3-3
testSides(t, "", func(t *testing.T, side connSide) {
tc := newTestConn(t, side, func(p *transportParameters) {
p.initialSrcConnID = []byte("invalid")
})
tc.ignoreFrame(frameTypeAck)
tc.ignoreFrame(frameTypeCrypto)
tc.writeFrames(packetTypeInitial,
debugFrameCrypto{
data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
})
if side == clientSide {
// Server transport parameters are carried in the Handshake packet.
tc.writeFrames(packetTypeHandshake,
debugFrameCrypto{
data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
})
}
tc.wantFrame("initial_source_connection_id transport parameter mismatch",
packetTypeInitial, debugFrameConnectionCloseTransport{
code: errTransportParameter,
})
})
}
4 changes: 4 additions & 0 deletions internal/quic/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn {
TLSConfig: newTestTLSConfig(side),
}
peerProvidedParams := defaultTransportParameters()
peerProvidedParams.initialSrcConnID = testPeerConnID(0)
if side == clientSide {
peerProvidedParams.originalDstConnID = testLocalConnID(-1)
}
for _, o := range opts {
switch o := o.(type) {
case func(*Config):
Expand Down

0 comments on commit 21814e7

Please sign in to comment.