Skip to content

Commit

Permalink
Fix inconsistencies in timer implementations
Browse files Browse the repository at this point in the history
Replaces goroutine based ack timer with callback pattern used in
other parts of the project.

Adds pending timer tick counters to prevent races between Stop and
timers firing. This fixes an edge case where callacks may have
fired after stop was called.
  • Loading branch information
paulwe committed Jul 1, 2024
1 parent a8bc9b8 commit e23ff7a
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 67 deletions.
22 changes: 13 additions & 9 deletions ack_timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type ackTimerObserver interface {
onAckTimeout()
}

type ackTimerState int
type ackTimerState uint8

const (
ackTimerStopped ackTimerState = iota
Expand All @@ -28,10 +28,11 @@ const (

// ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1
type ackTimer struct {
timer *time.Timer
observer ackTimerObserver
mutex sync.RWMutex
mutex sync.Mutex
state ackTimerState
timer *time.Timer
pending uint8
}

// newAckTimer creates a new acknowledgement timer used to enable delayed ack.
Expand All @@ -44,7 +45,7 @@ func newAckTimer(observer ackTimerObserver) *ackTimer {

func (t *ackTimer) timeout() {
t.mutex.Lock()
if t.state == ackTimerStarted {
if t.pending--; t.pending == 0 && t.state == ackTimerStarted {
t.state = ackTimerStopped
defer t.observer.onAckTimeout()
}
Expand All @@ -62,6 +63,7 @@ func (t *ackTimer) start() bool {
}

t.state = ackTimerStarted
t.pending++
t.timer.Reset(ackInterval)
return true
}
Expand All @@ -73,7 +75,9 @@ func (t *ackTimer) stop() {
defer t.mutex.Unlock()

if t.state == ackTimerStarted {
t.timer.Stop()
if t.timer.Stop() {
t.pending--
}
t.state = ackTimerStopped
}
}
Expand All @@ -84,17 +88,17 @@ func (t *ackTimer) close() {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.state == ackTimerStarted {
t.timer.Stop()
if t.state == ackTimerStarted && t.timer.Stop() {
t.pending--
}
t.state = ackTimerClosed
}

// isRunning tests if the timer is running.
// Debug purpose only
func (t *ackTimer) isRunning() bool {
t.mutex.RLock()
defer t.mutex.RUnlock()
t.mutex.Lock()
defer t.mutex.Unlock()

return t.state == ackTimerStarted
}
113 changes: 55 additions & 58 deletions rtx_timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,28 @@ type rtxTimerObserver interface {
onRetransmissionFailure(timerID int)
}

type rtxTimerState uint8

const (
rtxTimerStopped rtxTimerState = iota
rtxTimerStarted
rtxTimerClosed
)

// rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1
type rtxTimer struct {
id int
timer *time.Timer
observer rtxTimerObserver
id int
maxRetrans uint
stopFunc stopTimerLoop
closed bool
mutex sync.RWMutex
rtoMax float64
mutex sync.Mutex
rto float64
nRtos uint
state rtxTimerState
pending uint8
}

type stopTimerLoop func()

// newRTXTimer creates a new retransmission timer.
// if maxRetrans is set to 0, it will keep retransmitting until stop() is called.
// (it will never make onRetransmissionFailure() callback.
Expand All @@ -146,62 +155,50 @@ func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint,
if timer.rtoMax == 0 {
timer.rtoMax = defaultRTOMax
}
timer.timer = time.AfterFunc(math.MaxInt64, timer.timeout)
timer.timer.Stop()
return &timer
}

func (t *rtxTimer) calculateNextTimeout() time.Duration {
timeout := calculateNextTimeout(t.rto, t.nRtos, t.rtoMax)
return time.Duration(timeout) * time.Millisecond
}

func (t *rtxTimer) timeout() {
t.mutex.Lock()
if t.pending--; t.pending == 0 && t.state == rtxTimerStarted {
if t.nRtos++; t.maxRetrans == 0 || t.nRtos <= t.maxRetrans {
t.timer.Reset(t.calculateNextTimeout())
t.pending++
defer t.observer.onRetransmissionTimeout(t.id, t.nRtos)
} else {
t.state = rtxTimerStopped
defer t.observer.onRetransmissionFailure(t.id)
}
}
t.mutex.Unlock()
}

// start starts the timer.
func (t *rtxTimer) start(rto float64) bool {
t.mutex.Lock()
defer t.mutex.Unlock()

// this timer is already closed
if t.closed {
return false
}

// this is a noop if the timer is always running
if t.stopFunc != nil {
// this timer is already closed or aleady running
if t.state != rtxTimerStopped {
return false
}

// Note: rto value is intentionally not capped by RTO.Min to allow
// fast timeout for the tests. Non-test code should pass in the
// rto generated by rtoManager getRTO() method which caps the
// value at RTO.Min or at RTO.Max.
var nRtos uint

cancelCh := make(chan struct{})

go func() {
canceling := false

timer := time.NewTimer(math.MaxInt64)
timer.Stop()

for !canceling {
timeout := calculateNextTimeout(rto, nRtos, t.rtoMax)
timer.Reset(time.Duration(timeout) * time.Millisecond)

select {
case <-timer.C:
nRtos++
if t.maxRetrans == 0 || nRtos <= t.maxRetrans {
t.observer.onRetransmissionTimeout(t.id, nRtos)
} else {
t.stop()
t.observer.onRetransmissionFailure(t.id)
}
case <-cancelCh:
canceling = true
timer.Stop()
}
}
}()

t.stopFunc = func() {
close(cancelCh)
}

t.rto = rto
t.nRtos = 0
t.state = rtxTimerStarted
t.pending++
t.timer.Reset(t.calculateNextTimeout())
return true
}

Expand All @@ -210,9 +207,11 @@ func (t *rtxTimer) stop() {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.stopFunc != nil {
t.stopFunc()
t.stopFunc = nil
if t.state == rtxTimerStarted {
if t.timer.Stop() {
t.pending--
}
t.state = rtxTimerStopped
}
}

Expand All @@ -222,21 +221,19 @@ func (t *rtxTimer) close() {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.stopFunc != nil {
t.stopFunc()
t.stopFunc = nil
if t.state == rtxTimerStarted && t.timer.Stop() {
t.pending--
}

t.closed = true
t.state = rtxTimerClosed
}

// isRunning tests if the timer is running.
// Debug purpose only
func (t *rtxTimer) isRunning() bool {
t.mutex.RLock()
defer t.mutex.RUnlock()
t.mutex.Lock()
defer t.mutex.Unlock()

return (t.stopFunc != nil)
return t.state == rtxTimerStarted
}

func calculateNextTimeout(rto float64, nRtos uint, rtoMax float64) float64 {
Expand Down

0 comments on commit e23ff7a

Please sign in to comment.