Skip to content

Commit

Permalink
implement waitAll (#5)
Browse files Browse the repository at this point in the history
* attempt to implement waitAll

* WaitAll with options

* cover context cancel case in WaitAll

* update test comments

* waitAll

* prevent unittest from stuck

* limit all test to finish in 3 seconds

* tweaks

* tweaks

* tweak

* - distinguash softClose and hardClose
- waitOne as a function
  • Loading branch information
haitch committed May 20, 2020
1 parent fc5fa4c commit 1500ff5
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 25 deletions.
41 changes: 26 additions & 15 deletions async_task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@ func newTestContext(t *testing.T) context.Context {
return context.WithValue(context.TODO(), testContextKey, t)
}

func getCountingTask(sleepDuration time.Duration) asynctask.AsyncFunc {
func newTestContextWithTimeout(t *testing.T, timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.WithValue(context.TODO(), testContextKey, t), timeout)
}

func getCountingTask(countTo int, sleepInterval time.Duration) asynctask.AsyncFunc {
return func(ctx context.Context) (interface{}, error) {
t := ctx.Value(testContextKey).(*testing.T)

result := 0
for i := 0; i < 10; i++ {
for i := 0; i < countTo; i++ {
select {
case <-time.After(sleepDuration):
case <-time.After(sleepInterval):
t.Logf(" working %d", i)
result = i
case <-ctx.Done():
Expand All @@ -38,9 +42,10 @@ func getCountingTask(sleepDuration time.Duration) asynctask.AsyncFunc {

func TestEasyCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

t1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
assert.Equal(t, asynctask.StateRunning, t1.State(), "Task should queued to Running")

rawResult, err := t1.Wait(ctx)
Expand All @@ -61,14 +66,15 @@ func TestEasyCase(t *testing.T) {
result = rawResult.(int)
assert.Equal(t, result, 9)

assert.True(t, elapsed.Microseconds() < 2, "Second wait should return immediately")
assert.True(t, elapsed.Microseconds() < 3, "Second wait should return immediately")
}

func TestCancelFunc(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

t1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
assert.Equal(t, asynctask.StateRunning, t1.State(), "Task should queued to Running")

time.Sleep(time.Second * 1)
Expand Down Expand Up @@ -98,10 +104,11 @@ func TestCancelFunc(t *testing.T) {

func TestConsistentResultAfterCancel(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
t2 := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

t1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
t2 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
assert.Equal(t, asynctask.StateRunning, t1.State(), "Task should queued to Running")

time.Sleep(time.Second * 1)
Expand All @@ -126,6 +133,8 @@ func TestConsistentResultAfterCancel(t *testing.T) {

func TestCompletedTask(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

tsk := asynctask.NewCompletedTask()
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should in CompletedState")
Expand All @@ -135,19 +144,21 @@ func TestCompletedTask(t *testing.T) {
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")

// you get nil result and nil error
result, err := tsk.Wait(context.TODO())
result, err := tsk.Wait(ctx)
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")
assert.NoError(t, err)
assert.Nil(t, result)
}

func TestCrazyCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
numOfTasks := 10000
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

numOfTasks := 8000 // if you have --race switch on: limit on 8128 simultaneously alive goroutines is exceeded, dying
tasks := map[int]*asynctask.TaskStatus{}
for i := 0; i < numOfTasks; i++ {
tasks[i] = asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
tasks[i] = asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
}

time.Sleep(200 * time.Millisecond)
Expand Down
28 changes: 18 additions & 10 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,19 @@ func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc {
}
}

func getErrorTask(sleepDuration time.Duration) asynctask.AsyncFunc {
func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc {
return func(ctx context.Context) (interface{}, error) {
time.Sleep(sleepDuration)
return nil, errors.New("not found")
return nil, errors.New(errorString)
}
}

func TestTimeoutCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
tsk := asynctask.Start(ctx, getCountingTask(200*time.Millisecond))
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

tsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
_, err := tsk.WaitWithTimeout(ctx, 300*time.Millisecond)
assert.True(t, errors.Is(err, context.DeadlineExceeded), "expecting DeadlineExceeded")

Expand All @@ -57,25 +59,31 @@ func TestTimeoutCase(t *testing.T) {

func TestPanicCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

tsk := asynctask.Start(ctx, getPanicTask(200*time.Millisecond))
_, err := tsk.WaitWithTimeout(ctx, 300*time.Millisecond)
assert.True(t, errors.Is(err, asynctask.ErrPanic), "expecting ErrPanic")
}

func TestErrorCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
tsk := asynctask.Start(ctx, getErrorTask(200*time.Millisecond))
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

tsk := asynctask.Start(ctx, getErrorTask("dummy error", 200*time.Millisecond))
_, err := tsk.WaitWithTimeout(ctx, 300*time.Millisecond)
assert.Error(t, err)
assert.False(t, errors.Is(err, asynctask.ErrPanic), "not expecting ErrPanic")
assert.False(t, errors.Is(err, context.DeadlineExceeded), "not expecting DeadlineExceeded")
assert.Equal(t, "not found", err.Error())
assert.Equal(t, "dummy error", err.Error())
}

func TestPointerErrorCase(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

// nil point of a type that implement error
var pe *pointerError = nil
Expand All @@ -84,7 +92,6 @@ func TestPointerErrorCase(t *testing.T) {
// now you get a non-nil error
assert.False(t, err == nil, "reason this test is needed")

ctx := newTestContext(t)
tsk := asynctask.Start(ctx, func(ctx context.Context) (interface{}, error) {
time.Sleep(100 * time.Millisecond)
var pe *pointerError = nil
Expand All @@ -98,6 +105,8 @@ func TestPointerErrorCase(t *testing.T) {

func TestStructErrorCase(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

// nil point of a type that implement error
var se structError
Expand All @@ -106,7 +115,6 @@ func TestStructErrorCase(t *testing.T) {
// now you get a non-nil error
assert.False(t, err == nil, "reason this test is needed")

ctx := newTestContext(t)
tsk := asynctask.Start(ctx, func(ctx context.Context) (interface{}, error) {
time.Sleep(100 * time.Millisecond)
var se structError
Expand Down
87 changes: 87 additions & 0 deletions wait_all.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package asynctask

import (
"context"
"fmt"
"sync"
)

// WaitAllOptions defines options for WaitAll function
type WaitAllOptions struct {
// FailFast set to true will indicate WaitAll to return on first error it sees.
FailFast bool
}

// WaitAll block current thread til all task finished.
// first error from any tasks passed in will be returned.
func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...*TaskStatus) error {
tasksCount := len(tasks)

mutex := sync.Mutex{}
errorChClosed := false
errorCh := make(chan error, tasksCount)
// hard close channel
defer close(errorCh)

for _, tsk := range tasks {
go waitOne(ctx, tsk, errorCh, &errorChClosed, &mutex)
}

runningTasks := tasksCount
var errList []error
for {
select {
case err := <-errorCh:
runningTasks--
if err != nil {
// return immediately after receive first error.
if options.FailFast {
softCloseChannel(&mutex, &errorChClosed)
return err
}

errList = append(errList, err)
}
case <-ctx.Done():
softCloseChannel(&mutex, &errorChClosed)
return fmt.Errorf("WaitAll context canceled: %w", ctx.Err())
}

// are we finished yet?
if runningTasks == 0 {
softCloseChannel(&mutex, &errorChClosed)
break
}
}

// we have at least 1 error, return first one.
// caller can get error for individual task by using Wait(),
// it would return immediately after this WaitAll()
if len(errList) > 0 {
return errList[0]
}

// no error at all.
return nil
}

func waitOne(ctx context.Context, tsk *TaskStatus, errorCh chan<- error, errorChClosed *bool, mutex *sync.Mutex) {
_, err := tsk.Wait(ctx)

// why mutex?
// if all tasks start using same context (unittest is good example)
// and that context got canceled, all task fail at same time.
// first one went in and close the channel, while another one already went through gate check.
// raise a panic with send to closed channel.
mutex.Lock()
defer mutex.Unlock()
if !*errorChClosed {
errorCh <- err
}
}

func softCloseChannel(mutex *sync.Mutex, closed *bool) {
mutex.Lock()
defer mutex.Unlock()
*closed = true
}
102 changes: 102 additions & 0 deletions wait_all_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package asynctask_test

import (
"context"
"errors"
"testing"
"time"

"github.com/Azure/go-asynctask"
"github.com/stretchr/testify/assert"
)

func TestWaitAll(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond))
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond))
completedTsk := asynctask.NewCompletedTask()

start := time.Now()
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk)
elapsed := time.Since(start)
assert.NoError(t, err)
// should only finish after longest task.
assert.True(t, elapsed > 10*200*time.Millisecond)
}

func TestWaitAllFailFastCase(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond))
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond))
completedTsk := asynctask.NewCompletedTask()

start := time.Now()
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk, errorTsk, panicTsk, completedTsk)
countingTskState := countingTsk.State()
panicTskState := countingTsk.State()
elapsed := time.Since(start)
assert.Error(t, err)
assert.Equal(t, "expected error", err.Error())
// should fail before we finish panic task
assert.True(t, elapsed.Milliseconds() < 15)

// since we pass FailFast, countingTsk and panicTsk should be still running
assert.Equal(t, asynctask.StateRunning, countingTskState)
assert.Equal(t, asynctask.StateRunning, panicTskState)
}

func TestWaitAllErrorCase(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond))
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond))
completedTsk := asynctask.NewCompletedTask()

start := time.Now()
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: false}, countingTsk, errorTsk, panicTsk, completedTsk)
countingTskState := countingTsk.State()
panicTskState := panicTsk.State()
elapsed := time.Since(start)
assert.Error(t, err)
assert.Equal(t, "expected error", err.Error())
// should only finish after longest task.
assert.True(t, elapsed > 10*200*time.Millisecond)

// since we pass FailFast, countingTsk and panicTsk should be still running
assert.Equal(t, asynctask.StateCompleted, countingTskState, "countingTask should finished")
assert.Equal(t, asynctask.StateFailed, panicTskState, "panic task should failed")
}

func TestWaitAllCanceled(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()

countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond))
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond))
completedTsk := asynctask.NewCompletedTask()

waitCtx, cancelFunc1 := context.WithTimeout(ctx, 5*time.Millisecond)
defer cancelFunc1()

start := time.Now()
err := asynctask.WaitAll(waitCtx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk)
elapsed := time.Since(start)
assert.Error(t, err)
t.Log(err.Error())
assert.True(t, errors.Is(err, context.DeadlineExceeded))
// should return before first task
assert.True(t, elapsed < 10*2*time.Millisecond)
}

0 comments on commit 1500ff5

Please sign in to comment.