Skip to content

Commit

Permalink
Limit maximum tsn queued by the association
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt authored and MarcoPolo committed Apr 3, 2024
1 parent 338310b commit adadf96
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 1 deletion.
6 changes: 6 additions & 0 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ const (
// other constants
const (
acceptChSize = 16
// maxTSNOffset is the maximum offset of a received chunk TSN from the cummulative TSN
// we have seen so far that we will enqueue.
// For a chunk to be enqueued chunk.tsn < cummulativeTSN + maxTSNOffset
// This allows us to not enqueue too many bytes over the receive window in case of out
// of order delivery. A buffer of 1000 TSNs implies an excess of roughly 2MB.
maxTSNOffset = 2000
)

func getAssociationStateString(a uint32) string {
Expand Down
148 changes: 148 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -2731,6 +2732,153 @@ loop:
return a1, a2, nil
}

// udpDiscardReader blocks all reads after block is set to true.
// This allows us to send arbitrary packets on a stream and block the packets received in response
type udpDiscardReader struct {
net.Conn
ctx context.Context
block atomic.Bool
}

func (d *udpDiscardReader) Read(b []byte) (n int, err error) {
if d.block.Load() {
<-d.ctx.Done()
return 0, d.ctx.Err()
}
return d.Conn.Read(b)
}

func TestAssociationReceiveWindow(t *testing.T) {
udp1, udp2 := createUDPConnPair()
ctx, cancel := context.WithCancel(context.Background())
dudp1 := &udpDiscardReader{Conn: udp1, ctx: ctx}
createAssociations := func() (*Association, *Association, error) {
loggerFactory := logging.NewDefaultLoggerFactory()

a1Chan := make(chan interface{})
a2Chan := make(chan interface{})

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: dudp1,
LoggerFactory: loggerFactory,
})
if err2 != nil {
a1Chan <- err2
} else {
a1Chan <- a
}
}()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: udp2,
LoggerFactory: loggerFactory,
MaxReceiveBufferSize: 100_000,
})
if err2 != nil {
a2Chan <- err2
} else {
a2Chan <- a
}
}()

var a1 *Association
var a2 *Association

loop:
for {
select {
case v1 := <-a1Chan:
switch v := v1.(type) {
case *Association:
a1 = v
if a2 != nil {
break loop
}
case error:
return nil, nil, v
}
case v2 := <-a2Chan:
switch v := v2.(type) {
case *Association:
a2 = v
if a1 != nil {
break loop
}
case error:
return nil, nil, v
}
}
}
return a1, a2, nil
}
// a1 is the association used for sending data
// a2 is the association with receive window of 100kB which we will
// try to bypass
a1, a2, err := createAssociations()

require.NoError(t, err)
defer a2.Close()
defer a1.Close()
s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
require.NoError(t, err)
defer s1.Close()
s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
dudp1.block.Store(true)

s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
s2, err := a2.AcceptStream()
require.NoError(t, err)
require.Equal(t, uint16(1), s2.streamIdentifier)

done := make(chan bool)
go func() {
chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks = chunks[:1]
chunk := chunks[0]
// Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue
chunk.tsn = a1.myNextTSN + 1e9
for chunk.tsn > a1.myNextTSN {
select {
case <-done:
return
default:
}
chunk.tsn -= 1
pp := a1.bundleDataChunksIntoPackets(chunks)
for _, p := range pp {
raw, err := p.marshal(true)
if err != nil {
return
}
_, err = a1.netConn.Write(raw)
if err != nil {
return
}
}
if chunk.tsn%10 == 0 {
time.Sleep(10 * time.Millisecond)
}
}
}()

for cnt := 0; cnt < 15; cnt++ {
bytesQueued := s2.getNumBytesInReassemblyQueue()
if bytesQueued > 5_000_000 {
t.Error("too many bytes enqueued with receive window of 10kb", bytesQueued)
break
}
t.Log("bytes queued", bytesQueued)
time.Sleep(1 * time.Second)
}
close(done)
cancel()
}

func TestAssociation_Shutdown(t *testing.T) {
checkGoroutineLeaks(t)

Expand Down
2 changes: 1 addition & 1 deletion payload_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (q *payloadQueue) updateSortedKeys() {

func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool {
_, ok := q.chunkMap[p.tsn]
if ok || sna32LTE(p.tsn, cumulativeTSN) {
if ok || sna32LTE(p.tsn, cumulativeTSN) || sna32GTE(p.tsn, cumulativeTSN+maxTSNOffset) {
return false
}
return true
Expand Down

0 comments on commit adadf96

Please sign in to comment.