diff --git a/association.go b/association.go index 3b928820..75b26397 100644 --- a/association.go +++ b/association.go @@ -99,6 +99,12 @@ const ( // other constants const ( acceptChSize = 16 + // maxTSNOffset is the maximum offset of a received chunk TSN from the cummulative TSN + // we have seen so far that we will enqueue. + // For a chunk to be enqueued chunk.tsn < cummulativeTSN + maxTSNOffset + // This allows us to not enqueue too many bytes over the receive window in case of out + // of order delivery. A buffer of 1000 TSNs implies an excess of roughly 2MB. + maxTSNOffset = 2000 ) func getAssociationStateString(a uint32) string { diff --git a/association_test.go b/association_test.go index 8a50cf29..2c8681e1 100644 --- a/association_test.go +++ b/association_test.go @@ -19,6 +19,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "testing" "time" @@ -2731,6 +2732,153 @@ loop: return a1, a2, nil } +// udpDiscardReader blocks all reads after block is set to true. +// This allows us to send arbitrary packets on a stream and block the packets received in response +type udpDiscardReader struct { + net.Conn + ctx context.Context + block atomic.Bool +} + +func (d *udpDiscardReader) Read(b []byte) (n int, err error) { + if d.block.Load() { + <-d.ctx.Done() + return 0, d.ctx.Err() + } + return d.Conn.Read(b) +} + +func TestAssociationReceiveWindow(t *testing.T) { + udp1, udp2 := createUDPConnPair() + ctx, cancel := context.WithCancel(context.Background()) + dudp1 := &udpDiscardReader{Conn: udp1, ctx: ctx} + createAssociations := func() (*Association, *Association, error) { + loggerFactory := logging.NewDefaultLoggerFactory() + + a1Chan := make(chan interface{}) + a2Chan := make(chan interface{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go func() { + a, err2 := createClientWithContext(ctx, Config{ + NetConn: dudp1, + LoggerFactory: loggerFactory, + }) + if err2 != nil { + a1Chan <- err2 + } else { + a1Chan <- a + } + }() + + go func() { + a, err2 := createClientWithContext(ctx, Config{ + NetConn: udp2, + LoggerFactory: loggerFactory, + MaxReceiveBufferSize: 100_000, + }) + if err2 != nil { + a2Chan <- err2 + } else { + a2Chan <- a + } + }() + + var a1 *Association + var a2 *Association + + loop: + for { + select { + case v1 := <-a1Chan: + switch v := v1.(type) { + case *Association: + a1 = v + if a2 != nil { + break loop + } + case error: + return nil, nil, v + } + case v2 := <-a2Chan: + switch v := v2.(type) { + case *Association: + a2 = v + if a1 != nil { + break loop + } + case error: + return nil, nil, v + } + } + } + return a1, a2, nil + } + // a1 is the association used for sending data + // a2 is the association with receive window of 100kB which we will + // try to bypass + a1, a2, err := createAssociations() + + require.NoError(t, err) + defer a2.Close() + defer a1.Close() + s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary) + require.NoError(t, err) + defer s1.Close() + s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) + dudp1.block.Store(true) + + s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) + s2, err := a2.AcceptStream() + require.NoError(t, err) + require.Equal(t, uint16(1), s2.streamIdentifier) + + done := make(chan bool) + go func() { + chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary) + chunks = chunks[:1] + chunk := chunks[0] + // Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue + chunk.tsn = a1.myNextTSN + 1e9 + for chunk.tsn > a1.myNextTSN { + select { + case <-done: + return + default: + } + chunk.tsn -= 1 + pp := a1.bundleDataChunksIntoPackets(chunks) + for _, p := range pp { + raw, err := p.marshal(true) + if err != nil { + return + } + _, err = a1.netConn.Write(raw) + if err != nil { + return + } + } + if chunk.tsn%10 == 0 { + time.Sleep(10 * time.Millisecond) + } + } + }() + + for cnt := 0; cnt < 15; cnt++ { + bytesQueued := s2.getNumBytesInReassemblyQueue() + if bytesQueued > 5_000_000 { + t.Error("too many bytes enqueued with receive window of 10kb", bytesQueued) + break + } + t.Log("bytes queued", bytesQueued) + time.Sleep(1 * time.Second) + } + close(done) + cancel() +} + func TestAssociation_Shutdown(t *testing.T) { checkGoroutineLeaks(t) diff --git a/payload_queue.go b/payload_queue.go index e5925a51..1510bb6f 100644 --- a/payload_queue.go +++ b/payload_queue.go @@ -38,7 +38,7 @@ func (q *payloadQueue) updateSortedKeys() { func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool { _, ok := q.chunkMap[p.tsn] - if ok || sna32LTE(p.tsn, cumulativeTSN) { + if ok || sna32LTE(p.tsn, cumulativeTSN) || sna32GTE(p.tsn, cumulativeTSN+maxTSNOffset) { return false } return true