Skip to content

Commit

Permalink
feat: add WaitForWithContext
Browse files Browse the repository at this point in the history
  • Loading branch information
ccoVeille committed Jul 2, 2024
1 parent 071a746 commit e1d8c98
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 9 deletions.
44 changes: 44 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ Concurrency helpers:
- [Async](#async)
- [Transaction](#transaction)
- [WaitFor](#waitfor)
- [WaitForWithContext](#waitforwithcontext)

Error handling:

Expand Down Expand Up @@ -3089,6 +3090,49 @@ iterations, duration, ok := lo.WaitFor(laterTrue, 10*time.Millisecond, 5*time.Mi
// false
```


### WaitForWithContext

Runs periodically until a condition is validated or context is invalid.

The condition receives also the context, so it can invalidate the process in the condition checker

```go
ctx := context.Background()

alwaysTrue := func(_ context.Context, i int) bool { return true }
alwaysFalse := func(_ context.Context, i int) bool { return false }
laterTrue := func(_ context.Context, i int) bool {
return i >= 5
}

iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond)
// 1
// 1ms
// true

iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysFalse, 10*time.Millisecond, time.Millisecond)
// 10
// 10ms
// false

iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, time.Millisecond)
// 5
// 5ms
// true

iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, 5*time.Millisecond)
// 2
// 10ms
// false

expiringCtx, cancel := context.WithTimeout(ctx, 5*time.Millisecond)
iterations, duration, ok := lo.WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, time.Millisecond)
// 5
// 5.1ms
// false
```

### Validate

Helper function that creates an error when a condition is not met.
Expand Down
29 changes: 20 additions & 9 deletions concurrency.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lo

import (
"context"
"sync"
"time"
)
Expand Down Expand Up @@ -99,28 +100,38 @@ func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A,

// WaitFor runs periodically until a condition is validated.
func WaitFor(condition func(i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
conditionWithContext := func(_ context.Context, i int) bool {
return condition(i)
}
return WaitForWithContext(context.Background(), conditionWithContext, maxDuration, tick)
}

// WaitForWithContext runs periodically until a condition is validated or context is canceled.
func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, i int) bool, maxDuration time.Duration, tick time.Duration) (int, time.Duration, bool) {
start := time.Now()

timer := time.NewTimer(maxDuration)
i := 0
if ctx.Err() != nil {
return i, time.Since(start), false
}

ctx, cleanCtx := context.WithTimeout(ctx, maxDuration)
ticker := time.NewTicker(tick)

defer func() {
timer.Stop()
cleanCtx()
ticker.Stop()
}()

i := 0

for {
select {
case <-timer.C:
case <-ctx.Done():
return i, time.Since(start), false
case <-ticker.C:
if condition(i) {
return i + 1, time.Since(start), true
}

i++
if condition(ctx, i-1) {
return i, time.Since(start), true
}
}
}
}
116 changes: 116 additions & 0 deletions concurrency_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lo

import (
"context"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -295,3 +296,118 @@ func TestWaitFor(t *testing.T) {
is.True(ok)
})
}

func TestWaitForWithContext(t *testing.T) {
t.Parallel()

testTimeout := 100 * time.Millisecond
longTimeout := 2 * testTimeout
shortTimeout := 4 * time.Millisecond

t.Run("exist condition works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

laterTrue := func(_ context.Context, i int) bool {
return i >= 5
}

iter, duration, ok := WaitForWithContext(context.Background(), laterTrue, longTimeout, time.Millisecond)
is.Equal(6, iter, "unexpected iteration count")
is.InEpsilon(6*time.Millisecond, duration, float64(500*time.Microsecond))
is.True(ok)
})

t.Run("counter is incremented", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

counter := 0
alwaysFalse := func(_ context.Context, i int) bool {
is.Equal(counter, i)
counter++
return false
}

iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 1050*time.Microsecond)
is.Equal(counter, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

alwaysTrue := func(_ context.Context, _ int) bool { return true }
alwaysFalse := func(_ context.Context, _ int) bool { return false }

t.Run("short timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("timeout works", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

shortTimeout := 4 * time.Millisecond
iter, duration, ok := WaitForWithContext(context.Background(), alwaysFalse, shortTimeout, 10*time.Millisecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("exist on first condition", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

iter, duration, ok := WaitForWithContext(context.Background(), alwaysTrue, 10*time.Millisecond, time.Millisecond)
is.Equal(1, iter, "unexpected iteration count")
is.InEpsilon(time.Millisecond, duration, float64(5*time.Microsecond))
is.True(ok)
})

t.Run("context cancellation stops everything", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

expiringCtx, clean := context.WithTimeout(context.Background(), 8*time.Millisecond)
t.Cleanup(func() {
clean()
})

iter, duration, ok := WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, 3*time.Millisecond)
is.Equal(2, iter, "unexpected iteration count")
is.InEpsilon(10*time.Millisecond, duration, float64(500*time.Microsecond))
is.False(ok)
})

t.Run("canceled context stops everything", func(t *testing.T) {
t.Parallel()

testWithTimeout(t, testTimeout)
is := assert.New(t)

canceledCtx, cancel := context.WithCancel(context.Background())
cancel()

iter, duration, ok := WaitForWithContext(canceledCtx, alwaysFalse, 100*time.Millisecond, 1050*time.Microsecond)
is.Equal(0, iter, "unexpected iteration count")
is.InEpsilon(1*time.Millisecond, duration, float64(5*time.Microsecond))
is.False(ok)
})
}

0 comments on commit e1d8c98

Please sign in to comment.