diff --git a/association.go b/association.go index 75b26397..a5310d6f 100644 --- a/association.go +++ b/association.go @@ -99,12 +99,17 @@ const ( // other constants const ( acceptChSize = 16 - // maxTSNOffset is the maximum offset of a received chunk TSN from the cummulative TSN - // we have seen so far that we will enqueue. - // For a chunk to be enqueued chunk.tsn < cummulativeTSN + maxTSNOffset - // This allows us to not enqueue too many bytes over the receive window in case of out - // of order delivery. A buffer of 1000 TSNs implies an excess of roughly 2MB. - maxTSNOffset = 2000 + // avgChunkSize is an estimate of the average chunk size. There is no theory behind + // this estimate. + avgChunkSize = 500 + // minTSNOffset is the minimum offset over the cummulative TSN that we will enqueue + // irrespective of the receive buffer size + // see Association.getMaxTSNOffset + minTSNOffset = 2000 + // maxTSNOffset is the maximum offset over the cummulative TSN that we will enqueue + // irrespective of the receive buffer size + // see Association.getMaxTSNOffset + maxTSNOffset = 40000 ) func getAssociationStateString(a uint32) string { @@ -1116,6 +1121,23 @@ func (a *Association) SRTT() float64 { return a.srtt.Load().(float64) //nolint:forcetypeassert } +// getMaxTSNOffset returns the maximum offset over the current cummulative TSN that +// we are willing to enqueue. Limiting the maximum offset limits the number of +// tsns we have in the payloadQueue map. This ensures that we don't use too much space in +// the map itself. This also ensures that we keep the bytes utilised in the receive +// buffer within a small multiple of the user provided max receive buffer size. +func (a *Association) getMaxTSNOffset() uint32 { + // 4 is a magic number here. There is no theory behind this. + offset := (a.maxReceiveBufferSize * 4) / avgChunkSize + if offset < minTSNOffset { + offset = minTSNOffset + } + if offset > maxTSNOffset { + offset = maxTSNOffset + } + return offset +} + func setSupportedExtensions(init *chunkInitCommon) { // nolint:godox // TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2 @@ -1384,7 +1406,7 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet { a.name, d.tsn, d.immediateSack, len(d.userData)) a.stats.incDATAs() - canPush := a.payloadQueue.canPush(d, a.peerLastTSN) + canPush := a.payloadQueue.canPush(d, a.peerLastTSN, a.getMaxTSNOffset()) if canPush { s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown) if s == nil { diff --git a/association_test.go b/association_test.go index 2c8681e1..89c25e24 100644 --- a/association_test.go +++ b/association_test.go @@ -2879,6 +2879,122 @@ func TestAssociationReceiveWindow(t *testing.T) { cancel() } +func TestAssociationMaxTSNOffset(t *testing.T) { + udp1, udp2 := createUDPConnPair() + createAssociations := func() (*Association, *Association, error) { + loggerFactory := logging.NewDefaultLoggerFactory() + + a1Chan := make(chan interface{}) + a2Chan := make(chan interface{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go func() { + a, err2 := createClientWithContext(ctx, Config{ + NetConn: udp1, + LoggerFactory: loggerFactory, + }) + if err2 != nil { + a1Chan <- err2 + } else { + a1Chan <- a + } + }() + + go func() { + a, err2 := createClientWithContext(ctx, Config{ + NetConn: udp2, + LoggerFactory: loggerFactory, + MaxReceiveBufferSize: 100_000, + }) + if err2 != nil { + a2Chan <- err2 + } else { + a2Chan <- a + } + }() + + var a1 *Association + var a2 *Association + + loop: + for { + select { + case v1 := <-a1Chan: + switch v := v1.(type) { + case *Association: + a1 = v + if a2 != nil { + break loop + } + case error: + return nil, nil, v + } + case v2 := <-a2Chan: + switch v := v2.(type) { + case *Association: + a2 = v + if a1 != nil { + break loop + } + case error: + return nil, nil, v + } + } + } + return a1, a2, nil + } + // a1 is the association used for sending data + // a2 is the association with receive window of 100kB which we will + // try to bypass + a1, a2, err := createAssociations() + + require.NoError(t, err) + defer a2.Close() + defer a1.Close() + s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary) + require.NoError(t, err) + defer s1.Close() + s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) + s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) + s2, err := a2.AcceptStream() + require.NoError(t, err) + require.Equal(t, uint16(1), s2.streamIdentifier) + + chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary) + chunks = chunks[:1] + sendChunk := func(tsn uint32) { + chunk := chunks[0] + // Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue + chunk.tsn = tsn + pp := a1.bundleDataChunksIntoPackets(chunks) + for _, p := range pp { + raw, err := p.marshal(true) + if err != nil { + t.Fatal(err) + return + } + _, err = a1.netConn.Write(raw) + if err != nil { + t.Fatal(err) + return + } + } + } + sendChunk(a1.myNextTSN + 100_000) + time.Sleep(100 * time.Millisecond) + require.Less(t, s2.getNumBytesInReassemblyQueue(), 1000) + + sendChunk(a1.myNextTSN + 10_000) + time.Sleep(100 * time.Millisecond) + require.Less(t, s2.getNumBytesInReassemblyQueue(), 1000) + + sendChunk(a1.myNextTSN + minTSNOffset - 100) + time.Sleep(100 * time.Millisecond) + require.Greater(t, s2.getNumBytesInReassemblyQueue(), 1000) +} + func TestAssociation_Shutdown(t *testing.T) { checkGoroutineLeaks(t) diff --git a/payload_queue.go b/payload_queue.go index 1510bb6f..a0b1b26f 100644 --- a/payload_queue.go +++ b/payload_queue.go @@ -36,7 +36,7 @@ func (q *payloadQueue) updateSortedKeys() { }) } -func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool { +func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32, maxTSNOffset uint32) bool { _, ok := q.chunkMap[p.tsn] if ok || sna32LTE(p.tsn, cumulativeTSN) || sna32GTE(p.tsn, cumulativeTSN+maxTSNOffset) { return false