Skip to content

Commit

Permalink
quic: fix data race in connection close
Browse files Browse the repository at this point in the history
We were failing to hold streamsState.streamsMu when removing
a closed stream from the conn's stream map.

Rework this to remove the mutex entirely.
The only access to the map that isn't on the conn's loop is
during stream creation. Send a message to the loop to
register the stream instead of using a mutex.

Change-Id: I2e87089e87c61a6ade8219dfb8acec3809bf95de
Reviewed-on: https://go-review.googlesource.com/c/net/+/545217
LUCI-TryBot-Result: Go LUCI <[email protected]>
Reviewed-by: Jonathan Amsterdam <[email protected]>
  • Loading branch information
neild committed Dec 18, 2023
1 parent 577e44a commit b952594
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 22 deletions.
31 changes: 28 additions & 3 deletions internal/quic/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,12 +369,37 @@ func (c *Conn) wake() {
}

// runOnLoop executes a function within the conn's loop goroutine.
func (c *Conn) runOnLoop(f func(now time.Time, c *Conn)) error {
func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) error {
donec := make(chan struct{})
c.sendMsg(func(now time.Time, c *Conn) {
msg := func(now time.Time, c *Conn) {
defer close(donec)
f(now, c)
})
}
if c.testHooks != nil {
// In tests, we can't rely on being able to send a message immediately:
// c.msgc might be full, and testConnHooks.nextMessage might be waiting
// for us to block before it processes the next message.
// To avoid a deadlock, we send the message in waitUntil.
// If msgc is empty, the message is buffered.
// If msgc is full, we block and let nextMessage process the queue.
msgc := c.msgc
c.testHooks.waitUntil(ctx, func() bool {
for {
select {
case msgc <- msg:
msgc = nil // send msg only once
case <-donec:
return true
case <-c.donec:
return true
default:
return false
}
}
})
} else {
c.sendMsg(msg)
}
select {
case <-donec:
case <-c.donec:
Expand Down
1 change: 1 addition & 0 deletions internal/quic/conn_async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ func runAsync[T any](tc *testConn, f func(context.Context) (T, error)) *asyncOp[
})
// Wait for the operation to either finish or block.
<-as.notify
tc.wait()
return a
}

Expand Down
20 changes: 8 additions & 12 deletions internal/quic/conn_streams.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ import (
)

type streamsState struct {
queue queue[*Stream] // new, peer-created streams

streamsMu sync.Mutex
streams map[streamID]*Stream
queue queue[*Stream] // new, peer-created streams
streams map[streamID]*Stream

// Limits on the number of streams, indexed by streamType.
localLimit [streamTypeCount]localStreamLimits
Expand Down Expand Up @@ -82,9 +80,6 @@ func (c *Conn) NewSendOnlyStream(ctx context.Context) (*Stream, error) {
}

func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, error) {
c.streams.streamsMu.Lock()
defer c.streams.streamsMu.Unlock()

num, err := c.streams.localLimit[styp].open(ctx, c)
if err != nil {
return nil, err
Expand All @@ -100,7 +95,12 @@ func (c *Conn) newLocalStream(ctx context.Context, styp streamType) (*Stream, er
s.inUnlock()
s.outUnlock()

c.streams.streams[s.id] = s
// Modify c.streams on the conn's loop.
if err := c.runOnLoop(ctx, func(now time.Time, c *Conn) {
c.streams.streams[s.id] = s
}); err != nil {
return nil, err
}
return s, nil
}

Expand All @@ -119,8 +119,6 @@ const (
// streamForID returns the stream with the given id.
// If the stream does not exist, it returns nil.
func (c *Conn) streamForID(id streamID) *Stream {
c.streams.streamsMu.Lock()
defer c.streams.streamsMu.Unlock()
return c.streams.streams[id]
}

Expand All @@ -146,8 +144,6 @@ func (c *Conn) streamForFrame(now time.Time, id streamID, ftype streamFrameType)
}
}

c.streams.streamsMu.Lock()
defer c.streams.streamsMu.Unlock()
s, isOpen := c.streams.streams[id]
if s != nil {
return s
Expand Down
44 changes: 44 additions & 0 deletions internal/quic/conn_streams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"fmt"
"io"
"math"
"sync"
"testing"
)

Expand Down Expand Up @@ -478,3 +479,46 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) {
t.Fatalf("after test, stream send queue is not empty; should be")
}
}

// TestStreamsCreateConcurrency creates streams on the client conn from
// several goroutines at once while a server goroutine accepts and closes
// them, then verifies that the server accepted exactly as many streams
// as the client goroutines created.
func TestStreamsCreateConcurrency(t *testing.T) {
	const (
		numCreators       = 10 // goroutines opening streams concurrently
		streamsPerCreator = 10 // streams opened by each goroutine
	)
	client, server := newLocalConnPair(t, &Config{}, &Config{})

	// The acceptor counts streams until AcceptStream returns an error,
	// then reports the total on acceptedc before the channel is closed.
	// NOTE(review): presumably the error arrives once the conns are
	// aborted below — confirm against AcceptStream's contract.
	acceptedc := make(chan int)
	go func() {
		defer close(acceptedc)
		accepted := 0
		for {
			st, err := server.AcceptStream(context.Background())
			if err != nil {
				acceptedc <- accepted
				return
			}
			st.Close()
			accepted++
		}
	}()

	// Each creator opens, flushes, and closes its streams one at a time.
	// A creator stops at its first NewStream failure.
	var creators sync.WaitGroup
	creators.Add(numCreators)
	for g := 0; g < numCreators; g++ {
		go func() {
			defer creators.Done()
			for n := 0; n < streamsPerCreator; n++ {
				st, err := client.NewStream(context.Background())
				if err != nil {
					t.Errorf("NewStream: %v", err)
					return
				}
				st.Flush()
				st.Close()
			}
		}()
	}
	creators.Wait()

	// Abort both conns only after every creator has finished, then
	// collect the acceptor's final count.
	client.Abort(nil)
	server.Abort(nil)
	want := numCreators * streamsPerCreator
	if got := <-acceptedc; got != want {
		t.Errorf("accepted %v streams, want %v", got, want)
	}
}
19 changes: 12 additions & 7 deletions internal/quic/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,25 @@ func TestConnTestConn(t *testing.T) {
t.Errorf("new conn timeout=%v, want %v (max_idle_timeout)", got, want)
}

var ranAt time.Time
tc.conn.runOnLoop(func(now time.Time, c *Conn) {
ranAt = now
})
ranAt, _ := runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
when = now
})
return
}).result()
if !ranAt.Equal(tc.endpoint.now) {
t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now)
}
tc.wait()

nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2)
tc.advanceTo(nextTime)
tc.conn.runOnLoop(func(now time.Time, c *Conn) {
ranAt = now
})
ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) {
tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) {
when = now
})
return
}).result()
if !ranAt.Equal(nextTime) {
t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime)
}
Expand Down

0 comments on commit b952594

Please sign in to comment.