diff --git a/wait_any.go b/wait_any.go new file mode 100644 index 0000000..1d5cb51 --- /dev/null +++ b/wait_any.go @@ -0,0 +1,67 @@ +package asynctask + +import ( + "context" + "fmt" +) + +// WaitAnyOptions defines options for WaitAny function +type WaitAnyOptions struct { + // FailOnAnyError set to true will indicate WaitAny to return on first error it sees. + FailOnAnyError bool +} + +// WaitAny block current thread til any of task finished. +// first error from any tasks passed in will be returned if FailOnAnyError is set. +// first task end without error will end wait and return nil +func WaitAny(ctx context.Context, options *WaitAnyOptions, tasks ...Waitable) error { + tasksCount := len(tasks) + if tasksCount == 0 { + return nil + } + + if options == nil { + options = &WaitAnyOptions{} + } + + // tried to close channel before exit this func, + // but it's complicated with routines, and we don't want to delay the return. + // per https://stackoverflow.com/questions/8593645/is-it-ok-to-leave-a-channel-open, its ok to leave channel open, eventually it will be garbage collected. + // this assumes the tasks eventually finish, otherwise we will have a routine leak. + errorCh := make(chan error, tasksCount) + + for _, tsk := range tasks { + go waitOne(ctx, tsk, errorCh) + } + + runningTasks := tasksCount + var errList []error + for { + select { + case err := <-errorCh: + runningTasks-- + if err != nil { + // return immediately after receive first error if FailOnAnyError is set. + if options.FailOnAnyError { + return err + } + errList = append(errList, err) + } else { + // return immediately after first task completed. + return nil + } + case <-ctx.Done(): + return fmt.Errorf("WaitAny %w", ctx.Err()) + } + + // are we finished yet? + if runningTasks == 0 { + break + } + } + + // when all tasks failed and FailOnAnyError is not set, return first one. + // caller can get error for individual task by using Wait(), + // it would return immediately after this WaitAny() + return errList[0] +} diff --git a/wait_any_test.go b/wait_any_test.go new file mode 100644 index 0000000..754c99d --- /dev/null +++ b/wait_any_test.go @@ -0,0 +1,190 @@ +package asynctask_test + +import ( + "fmt" + "testing" + "time" + + "github.com/Azure/go-asynctask" + "github.com/stretchr/testify/assert" +) + +func TestWaitAnyNoTask(t *testing.T) { + t.Parallel() + ctx, _ := newTestContextWithTimeout(t, 2*time.Second) + + err := asynctask.WaitAny(ctx, nil) + assert.NoError(t, err) +} + +func TestWaitAny(t *testing.T) { + t.Parallel() + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 2*time.Second) + + start := time.Now() + countingTsk3 := asynctask.Start(ctx, getCountingTask(10, "countingPer2ms", 2*time.Millisecond)) + result := "something" + completedTsk := asynctask.NewCompletedTask(&result) + + err := asynctask.WaitAny(ctx, nil, countingTsk3, completedTsk) + elapsed := time.Since(start) + assert.NoError(t, err) + // should finish after right away + assert.True(t, elapsed < 2*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) + + start = time.Now() + countingTsk1 := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond)) + countingTsk2 := asynctask.Start(ctx, getCountingTask(10, "countingPer20ms", 20*time.Millisecond)) + countingTsk3 = asynctask.Start(ctx, getCountingTask(10, "countingPer2ms", 2*time.Millisecond)) + err = asynctask.WaitAny(ctx, &asynctask.WaitAnyOptions{FailOnAnyError: true}, countingTsk1, countingTsk2, countingTsk3) + elapsed = time.Since(start) + assert.NoError(t, err) + cancelTaskExecution() + + // should finish right after countingTsk3 + assert.True(t, elapsed >= 20*time.Millisecond && elapsed < 200*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) + + // counting task do testing.Logf in another go routine + // while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 + // wait minor time for the go routine to finish. + time.Sleep(1 * time.Millisecond) +} + +func TestWaitAnyContextCancel(t *testing.T) { + t.Parallel() + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 2*time.Second) + + start := time.Now() + + countingTsk1 := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond)) + countingTsk2 := asynctask.Start(ctx, getCountingTask(10, "countingPer20ms", 20*time.Millisecond)) + go func() { + time.Sleep(5 * time.Millisecond) + cancelTaskExecution() + }() + err := asynctask.WaitAny(ctx, nil, countingTsk1, countingTsk2) + elapsed := time.Since(start) + assert.Error(t, err) + assert.Equal(t, "WaitAny context canceled", err.Error(), "expecting context canceled error") + // should finish right after countingTsk3 + assert.True(t, elapsed >= 5*time.Millisecond && elapsed < 200*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) + + // counting task do testing.Logf in another go routine + // while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 + // wait minor time for the go routine to finish. + time.Sleep(1 * time.Millisecond) +} + +func TestWaitAnyErrorCase(t *testing.T) { + t.Parallel() + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 3*time.Second) + + start := time.Now() + errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond)) + result := "something" + completedTsk := asynctask.NewCompletedTask(&result) + err := asynctask.WaitAny(ctx, nil, errorTsk, completedTsk) + assert.NoError(t, err) + elapsed := time.Since(start) + // should finish after right away + assert.True(t, elapsed < 20*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) + completedTskState := completedTsk.State() + assert.Equal(t, asynctask.StateCompleted, completedTskState, "completed task should finished") + + start = time.Now() + countingTsk := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond)) + errorTsk = asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond)) + panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond)) + err = asynctask.WaitAny(ctx, nil, countingTsk, errorTsk, panicTsk) + // there is a succeed task + assert.NoError(t, err) + elapsed = time.Since(start) + + countingTskState := countingTsk.State() + panicTskState := panicTsk.State() + errTskState := errorTsk.State() + cancelTaskExecution() // all assertion variable captured, cancel counting task + + // should only finish after longest task. + assert.True(t, elapsed >= 40*10*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) + + assert.Equal(t, asynctask.StateCompleted, countingTskState, "countingTask should NOT finished") + assert.Equal(t, asynctask.StateFailed, errTskState, "error task should failed") + assert.Equal(t, asynctask.StateFailed, panicTskState, "panic task should Not failed") + + // counting task do testing.Logf in another go routine + // while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 + // wait minor time for the go routine to finish. + time.Sleep(1 * time.Millisecond) +} + +func TestWaitAnyAllFailCase(t *testing.T) { + t.Parallel() + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 3*time.Second) + + start := time.Now() + errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond)) + panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond)) + err := asynctask.WaitAny(ctx, nil, errorTsk, panicTsk) + assert.Error(t, err) + + panicTskState := panicTsk.State() + errTskState := errorTsk.State() + elapsed := time.Since(start) + cancelTaskExecution() // all assertion variable captured, cancel counting task + + assert.Equal(t, "expected error", err.Error(), "expecting first error") + // should finsh after both error + assert.True(t, elapsed >= 20*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) + + assert.Equal(t, asynctask.StateFailed, errTskState, "error task should failed") + assert.Equal(t, asynctask.StateFailed, panicTskState, "panic task should Not failed") + + // counting task do testing.Logf in another go routine + // while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 + // wait minor time for the go routine to finish. + time.Sleep(1 * time.Millisecond) +} + +func TestWaitAnyErrorWithFailOnAnyErrorCase(t *testing.T) { + t.Parallel() + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 3*time.Second) + + start := time.Now() + errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond)) + result := "something" + completedTsk := asynctask.NewCompletedTask(&result) + err := asynctask.WaitAny(ctx, &asynctask.WaitAnyOptions{FailOnAnyError: true}, errorTsk, completedTsk) + assert.NoError(t, err) + elapsed := time.Since(start) + // should finish after right away + assert.True(t, elapsed < 20*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) + + start = time.Now() + countingTsk := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond)) + errorTsk = asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond)) + panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond)) + err = asynctask.WaitAny(ctx, &asynctask.WaitAnyOptions{FailOnAnyError: true}, countingTsk, errorTsk, panicTsk) + assert.Error(t, err) + completedTskState := completedTsk.State() + assert.Equal(t, asynctask.StateCompleted, completedTskState, "completed task should finished") + + countingTskState := countingTsk.State() + panicTskState := panicTsk.State() + errTskState := errorTsk.State() + elapsed = time.Since(start) + cancelTaskExecution() // all assertion variable captured, cancel counting task + + assert.Equal(t, "expected error", err.Error(), "expecting first error") + // should finsh after first error + assert.True(t, elapsed >= 10*time.Millisecond && elapsed < 20*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) + + assert.Equal(t, asynctask.StateRunning, countingTskState, "countingTask should NOT finished") + assert.Equal(t, asynctask.StateFailed, errTskState, "error task should failed") + assert.Equal(t, asynctask.StateRunning, panicTskState, "panic task should Not failed") + + // counting task do testing.Logf in another go routine + // while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 + // wait minor time for the go routine to finish. + time.Sleep(1 * time.Millisecond) +}