diff --git a/ack_timer.go b/ack_timer.go index 879e86df..b6008d18 100644 --- a/ack_timer.go +++ b/ack_timer.go @@ -18,7 +18,7 @@ type ackTimerObserver interface { onAckTimeout() } -type ackTimerState int +type ackTimerState uint8 const ( ackTimerStopped ackTimerState = iota @@ -28,10 +28,11 @@ const ( // ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 type ackTimer struct { + timer *time.Timer observer ackTimerObserver - mutex sync.RWMutex + mutex sync.Mutex state ackTimerState - timer *time.Timer + pending uint8 } // newAckTimer creates a new acknowledgement timer used to enable delayed ack. @@ -44,7 +45,7 @@ func newAckTimer(observer ackTimerObserver) *ackTimer { func (t *ackTimer) timeout() { t.mutex.Lock() - if t.state == ackTimerStarted { + if t.pending--; t.pending == 0 && t.state == ackTimerStarted { t.state = ackTimerStopped defer t.observer.onAckTimeout() } @@ -62,6 +63,7 @@ func (t *ackTimer) start() bool { } t.state = ackTimerStarted + t.pending++ t.timer.Reset(ackInterval) return true } @@ -73,7 +75,9 @@ func (t *ackTimer) stop() { defer t.mutex.Unlock() if t.state == ackTimerStarted { - t.timer.Stop() + if t.timer.Stop() { + t.pending-- + } t.state = ackTimerStopped } } @@ -84,8 +88,8 @@ func (t *ackTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() - if t.state == ackTimerStarted { - t.timer.Stop() + if t.state == ackTimerStarted && t.timer.Stop() { + t.pending-- } t.state = ackTimerClosed } @@ -93,8 +97,8 @@ func (t *ackTimer) close() { // isRunning tests if the timer is running. // Debug purpose only func (t *ackTimer) isRunning() bool { - t.mutex.RLock() - defer t.mutex.RUnlock() + t.mutex.Lock() + defer t.mutex.Unlock() return t.state == ackTimerStarted } diff --git a/rtx_timer.go b/rtx_timer.go index 354825b5..1fea3931 100644 --- a/rtx_timer.go +++ b/rtx_timer.go @@ -118,19 +118,28 @@ type rtxTimerObserver interface { onRetransmissionFailure(timerID int) } +type rtxTimerState uint8 + +const ( + rtxTimerStopped rtxTimerState = iota + rtxTimerStarted + rtxTimerClosed +) + // rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 type rtxTimer struct { - id int + timer *time.Timer observer rtxTimerObserver + id int maxRetrans uint - stopFunc stopTimerLoop - closed bool - mutex sync.RWMutex rtoMax float64 + mutex sync.Mutex + rto float64 + nRtos uint + state rtxTimerState + pending uint8 } -type stopTimerLoop func() - // newRTXTimer creates a new retransmission timer. // if maxRetrans is set to 0, it will keep retransmitting until stop() is called. // (it will never make onRetransmissionFailure() callback. @@ -146,21 +155,38 @@ func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint, if timer.rtoMax == 0 { timer.rtoMax = defaultRTOMax } + timer.timer = time.AfterFunc(math.MaxInt64, timer.timeout) + timer.timer.Stop() return &timer } +func (t *rtxTimer) calculateNextTimeout() time.Duration { + timeout := calculateNextTimeout(t.rto, t.nRtos, t.rtoMax) + return time.Duration(timeout) * time.Millisecond +} + +func (t *rtxTimer) timeout() { + t.mutex.Lock() + if t.pending--; t.pending == 0 && t.state == rtxTimerStarted { + if t.nRtos++; t.maxRetrans == 0 || t.nRtos <= t.maxRetrans { + t.timer.Reset(t.calculateNextTimeout()) + t.pending++ + defer t.observer.onRetransmissionTimeout(t.id, t.nRtos) + } else { + t.state = rtxTimerStopped + defer t.observer.onRetransmissionFailure(t.id) + } + } + t.mutex.Unlock() +} + // start starts the timer. func (t *rtxTimer) start(rto float64) bool { t.mutex.Lock() defer t.mutex.Unlock() - // this timer is already closed - if t.closed { - return false - } - - // this is a noop if the timer is always running - if t.stopFunc != nil { + // this timer is already closed or aleady running + if t.state != rtxTimerStopped { return false } @@ -168,40 +194,11 @@ func (t *rtxTimer) start(rto float64) bool { // fast timeout for the tests. Non-test code should pass in the // rto generated by rtoManager getRTO() method which caps the // value at RTO.Min or at RTO.Max. - var nRtos uint - - cancelCh := make(chan struct{}) - - go func() { - canceling := false - - timer := time.NewTimer(math.MaxInt64) - timer.Stop() - - for !canceling { - timeout := calculateNextTimeout(rto, nRtos, t.rtoMax) - timer.Reset(time.Duration(timeout) * time.Millisecond) - - select { - case <-timer.C: - nRtos++ - if t.maxRetrans == 0 || nRtos <= t.maxRetrans { - t.observer.onRetransmissionTimeout(t.id, nRtos) - } else { - t.stop() - t.observer.onRetransmissionFailure(t.id) - } - case <-cancelCh: - canceling = true - timer.Stop() - } - } - }() - - t.stopFunc = func() { - close(cancelCh) - } - + t.rto = rto + t.nRtos = 0 + t.state = rtxTimerStarted + t.pending++ + t.timer.Reset(t.calculateNextTimeout()) return true } @@ -210,9 +207,11 @@ func (t *rtxTimer) stop() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == rtxTimerStarted { + if t.timer.Stop() { + t.pending-- + } + t.state = rtxTimerStopped } } @@ -222,21 +221,19 @@ func (t *rtxTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == rtxTimerStarted && t.timer.Stop() { + t.pending-- } - - t.closed = true + t.state = rtxTimerClosed } // isRunning tests if the timer is running. // Debug purpose only func (t *rtxTimer) isRunning() bool { - t.mutex.RLock() - defer t.mutex.RUnlock() + t.mutex.Lock() + defer t.mutex.Unlock() - return (t.stopFunc != nil) + return t.state == rtxTimerStarted } func calculateNextTimeout(rto float64, nRtos uint, rtoMax float64) float64 {