-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* attempt to implement waitAll * WaitAll with options * cover context cancel case in WaitAll * update test comments * waitAll * prevent unittest from stuck * limit all test to finish in 3 seconds * tweaks * tweaks * tweak * - distinguash softClose and hardClose - waitOne as a function
- Loading branch information
Showing
4 changed files
with
233 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package asynctask | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"sync" | ||
) | ||
|
||
// WaitAllOptions defines options for WaitAll function | ||
type WaitAllOptions struct { | ||
// FailFast set to true will indicate WaitAll to return on first error it sees. | ||
FailFast bool | ||
} | ||
|
||
// WaitAll block current thread til all task finished. | ||
// first error from any tasks passed in will be returned. | ||
func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...*TaskStatus) error { | ||
tasksCount := len(tasks) | ||
|
||
mutex := sync.Mutex{} | ||
errorChClosed := false | ||
errorCh := make(chan error, tasksCount) | ||
// hard close channel | ||
defer close(errorCh) | ||
|
||
for _, tsk := range tasks { | ||
go waitOne(ctx, tsk, errorCh, &errorChClosed, &mutex) | ||
} | ||
|
||
runningTasks := tasksCount | ||
var errList []error | ||
for { | ||
select { | ||
case err := <-errorCh: | ||
runningTasks-- | ||
if err != nil { | ||
// return immediately after receive first error. | ||
if options.FailFast { | ||
softCloseChannel(&mutex, &errorChClosed) | ||
return err | ||
} | ||
|
||
errList = append(errList, err) | ||
} | ||
case <-ctx.Done(): | ||
softCloseChannel(&mutex, &errorChClosed) | ||
return fmt.Errorf("WaitAll context canceled: %w", ctx.Err()) | ||
} | ||
|
||
// are we finished yet? | ||
if runningTasks == 0 { | ||
softCloseChannel(&mutex, &errorChClosed) | ||
break | ||
} | ||
} | ||
|
||
// we have at least 1 error, return first one. | ||
// caller can get error for individual task by using Wait(), | ||
// it would return immediately after this WaitAll() | ||
if len(errList) > 0 { | ||
return errList[0] | ||
} | ||
|
||
// no error at all. | ||
return nil | ||
} | ||
|
||
func waitOne(ctx context.Context, tsk *TaskStatus, errorCh chan<- error, errorChClosed *bool, mutex *sync.Mutex) { | ||
_, err := tsk.Wait(ctx) | ||
|
||
// why mutex? | ||
// if all tasks start using same context (unittest is good example) | ||
// and that context got canceled, all task fail at same time. | ||
// first one went in and close the channel, while another one already went through gate check. | ||
// raise a panic with send to closed channel. | ||
mutex.Lock() | ||
defer mutex.Unlock() | ||
if !*errorChClosed { | ||
errorCh <- err | ||
} | ||
} | ||
|
||
func softCloseChannel(mutex *sync.Mutex, closed *bool) { | ||
mutex.Lock() | ||
defer mutex.Unlock() | ||
*closed = true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package asynctask_test | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"testing" | ||
"time" | ||
|
||
"github.com/Azure/go-asynctask" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestWaitAll(t *testing.T) { | ||
t.Parallel() | ||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second) | ||
defer cancelFunc() | ||
|
||
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond)) | ||
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond)) | ||
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond)) | ||
completedTsk := asynctask.NewCompletedTask() | ||
|
||
start := time.Now() | ||
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk) | ||
elapsed := time.Since(start) | ||
assert.NoError(t, err) | ||
// should only finish after longest task. | ||
assert.True(t, elapsed > 10*200*time.Millisecond) | ||
} | ||
|
||
func TestWaitAllFailFastCase(t *testing.T) { | ||
t.Parallel() | ||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second) | ||
defer cancelFunc() | ||
|
||
countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond)) | ||
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond)) | ||
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond)) | ||
completedTsk := asynctask.NewCompletedTask() | ||
|
||
start := time.Now() | ||
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk, errorTsk, panicTsk, completedTsk) | ||
countingTskState := countingTsk.State() | ||
panicTskState := countingTsk.State() | ||
elapsed := time.Since(start) | ||
assert.Error(t, err) | ||
assert.Equal(t, "expected error", err.Error()) | ||
// should fail before we finish panic task | ||
assert.True(t, elapsed.Milliseconds() < 15) | ||
|
||
// since we pass FailFast, countingTsk and panicTsk should be still running | ||
assert.Equal(t, asynctask.StateRunning, countingTskState) | ||
assert.Equal(t, asynctask.StateRunning, panicTskState) | ||
} | ||
|
||
func TestWaitAllErrorCase(t *testing.T) { | ||
t.Parallel() | ||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second) | ||
defer cancelFunc() | ||
|
||
countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond)) | ||
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond)) | ||
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond)) | ||
completedTsk := asynctask.NewCompletedTask() | ||
|
||
start := time.Now() | ||
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: false}, countingTsk, errorTsk, panicTsk, completedTsk) | ||
countingTskState := countingTsk.State() | ||
panicTskState := panicTsk.State() | ||
elapsed := time.Since(start) | ||
assert.Error(t, err) | ||
assert.Equal(t, "expected error", err.Error()) | ||
// should only finish after longest task. | ||
assert.True(t, elapsed > 10*200*time.Millisecond) | ||
|
||
// since we pass FailFast, countingTsk and panicTsk should be still running | ||
assert.Equal(t, asynctask.StateCompleted, countingTskState, "countingTask should finished") | ||
assert.Equal(t, asynctask.StateFailed, panicTskState, "panic task should failed") | ||
} | ||
|
||
func TestWaitAllCanceled(t *testing.T) { | ||
t.Parallel() | ||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second) | ||
defer cancelFunc() | ||
|
||
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond)) | ||
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond)) | ||
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond)) | ||
completedTsk := asynctask.NewCompletedTask() | ||
|
||
waitCtx, cancelFunc1 := context.WithTimeout(ctx, 5*time.Millisecond) | ||
defer cancelFunc1() | ||
|
||
start := time.Now() | ||
err := asynctask.WaitAll(waitCtx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk) | ||
elapsed := time.Since(start) | ||
assert.Error(t, err) | ||
t.Log(err.Error()) | ||
assert.True(t, errors.Is(err, context.DeadlineExceeded)) | ||
// should return before first task | ||
assert.True(t, elapsed < 10*2*time.Millisecond) | ||
} |