From e23ff7ac4766737136ff9ce375a9759260931adc Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Mon, 1 Jul 2024 04:09:59 -0700 Subject: [PATCH] Fix inconsistencies in timer implementations Replaces goroutine based ack timer with callback pattern used in other parts of the project. Adds pending timer tick counters to prevent races between Stop and timers firing. This fixes an edge case where callacks may have fired after stop was called. --- ack_timer.go | 22 ++++++---- rtx_timer.go | 113 +++++++++++++++++++++++++-------------------------- 2 files changed, 68 insertions(+), 67 deletions(-) diff --git a/ack_timer.go b/ack_timer.go index 879e86df..b6008d18 100644 --- a/ack_timer.go +++ b/ack_timer.go @@ -18,7 +18,7 @@ type ackTimerObserver interface { onAckTimeout() } -type ackTimerState int +type ackTimerState uint8 const ( ackTimerStopped ackTimerState = iota @@ -28,10 +28,11 @@ const ( // ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 type ackTimer struct { + timer *time.Timer observer ackTimerObserver - mutex sync.RWMutex + mutex sync.Mutex state ackTimerState - timer *time.Timer + pending uint8 } // newAckTimer creates a new acknowledgement timer used to enable delayed ack. @@ -44,7 +45,7 @@ func newAckTimer(observer ackTimerObserver) *ackTimer { func (t *ackTimer) timeout() { t.mutex.Lock() - if t.state == ackTimerStarted { + if t.pending--; t.pending == 0 && t.state == ackTimerStarted { t.state = ackTimerStopped defer t.observer.onAckTimeout() } @@ -62,6 +63,7 @@ func (t *ackTimer) start() bool { } t.state = ackTimerStarted + t.pending++ t.timer.Reset(ackInterval) return true } @@ -73,7 +75,9 @@ func (t *ackTimer) stop() { defer t.mutex.Unlock() if t.state == ackTimerStarted { - t.timer.Stop() + if t.timer.Stop() { + t.pending-- + } t.state = ackTimerStopped } } @@ -84,8 +88,8 @@ func (t *ackTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() - if t.state == ackTimerStarted { - t.timer.Stop() + if t.state == ackTimerStarted && t.timer.Stop() { + t.pending-- } t.state = ackTimerClosed } @@ -93,8 +97,8 @@ func (t *ackTimer) close() { // isRunning tests if the timer is running. // Debug purpose only func (t *ackTimer) isRunning() bool { - t.mutex.RLock() - defer t.mutex.RUnlock() + t.mutex.Lock() + defer t.mutex.Unlock() return t.state == ackTimerStarted } diff --git a/rtx_timer.go b/rtx_timer.go index 354825b5..1fea3931 100644 --- a/rtx_timer.go +++ b/rtx_timer.go @@ -118,19 +118,28 @@ type rtxTimerObserver interface { onRetransmissionFailure(timerID int) } +type rtxTimerState uint8 + +const ( + rtxTimerStopped rtxTimerState = iota + rtxTimerStarted + rtxTimerClosed +) + // rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 type rtxTimer struct { - id int + timer *time.Timer observer rtxTimerObserver + id int maxRetrans uint - stopFunc stopTimerLoop - closed bool - mutex sync.RWMutex rtoMax float64 + mutex sync.Mutex + rto float64 + nRtos uint + state rtxTimerState + pending uint8 } -type stopTimerLoop func() - // newRTXTimer creates a new retransmission timer. // if maxRetrans is set to 0, it will keep retransmitting until stop() is called. // (it will never make onRetransmissionFailure() callback. @@ -146,21 +155,38 @@ func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint, if timer.rtoMax == 0 { timer.rtoMax = defaultRTOMax } + timer.timer = time.AfterFunc(math.MaxInt64, timer.timeout) + timer.timer.Stop() return &timer } +func (t *rtxTimer) calculateNextTimeout() time.Duration { + timeout := calculateNextTimeout(t.rto, t.nRtos, t.rtoMax) + return time.Duration(timeout) * time.Millisecond +} + +func (t *rtxTimer) timeout() { + t.mutex.Lock() + if t.pending--; t.pending == 0 && t.state == rtxTimerStarted { + if t.nRtos++; t.maxRetrans == 0 || t.nRtos <= t.maxRetrans { + t.timer.Reset(t.calculateNextTimeout()) + t.pending++ + defer t.observer.onRetransmissionTimeout(t.id, t.nRtos) + } else { + t.state = rtxTimerStopped + defer t.observer.onRetransmissionFailure(t.id) + } + } + t.mutex.Unlock() +} + // start starts the timer. func (t *rtxTimer) start(rto float64) bool { t.mutex.Lock() defer t.mutex.Unlock() - // this timer is already closed - if t.closed { - return false - } - - // this is a noop if the timer is always running - if t.stopFunc != nil { + // this timer is already closed or aleady running + if t.state != rtxTimerStopped { return false } @@ -168,40 +194,11 @@ func (t *rtxTimer) start(rto float64) bool { // fast timeout for the tests. Non-test code should pass in the // rto generated by rtoManager getRTO() method which caps the // value at RTO.Min or at RTO.Max. - var nRtos uint - - cancelCh := make(chan struct{}) - - go func() { - canceling := false - - timer := time.NewTimer(math.MaxInt64) - timer.Stop() - - for !canceling { - timeout := calculateNextTimeout(rto, nRtos, t.rtoMax) - timer.Reset(time.Duration(timeout) * time.Millisecond) - - select { - case <-timer.C: - nRtos++ - if t.maxRetrans == 0 || nRtos <= t.maxRetrans { - t.observer.onRetransmissionTimeout(t.id, nRtos) - } else { - t.stop() - t.observer.onRetransmissionFailure(t.id) - } - case <-cancelCh: - canceling = true - timer.Stop() - } - } - }() - - t.stopFunc = func() { - close(cancelCh) - } - + t.rto = rto + t.nRtos = 0 + t.state = rtxTimerStarted + t.pending++ + t.timer.Reset(t.calculateNextTimeout()) return true } @@ -210,9 +207,11 @@ func (t *rtxTimer) stop() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == rtxTimerStarted { + if t.timer.Stop() { + t.pending-- + } + t.state = rtxTimerStopped } } @@ -222,21 +221,19 @@ func (t *rtxTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == rtxTimerStarted && t.timer.Stop() { + t.pending-- } - - t.closed = true + t.state = rtxTimerClosed } // isRunning tests if the timer is running. // Debug purpose only func (t *rtxTimer) isRunning() bool { - t.mutex.RLock() - defer t.mutex.RUnlock() + t.mutex.Lock() + defer t.mutex.Unlock() - return (t.stopFunc != nil) + return t.state == rtxTimerStarted } func calculateNextTimeout(rto float64, nRtos uint, rtoMax float64) float64 {