diff --git a/go.mod b/go.mod index bcbe68e..bc69bc4 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,9 @@ require github.com/stretchr/testify v1.8.2 require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.1.0 // indirect + github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 6a56e69..ac4ccb3 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,13 @@ +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -10,8 +17,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/task.go b/task.go index 699c256..1c31861 100644 --- a/task.go +++ b/task.go @@ -12,8 +12,8 @@ import ( type AsyncFunc[T any] func(context.Context) (*T, error) // ActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface. -// Action is function that runs without return anything -// Func is function that runs and return something +// - Action is function that runs without return anything +// - Func is function that runs and return something func ActionToFunc(action func(context.Context) error) func(context.Context) (*interface{}, error) { return func(ctx context.Context) (*interface{}, error) { return nil, action(ctx) @@ -134,7 +134,7 @@ func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task fu defer record.waitGroup.Done() defer func() { if r := recover(); r != nil { - err := fmt.Errorf("Panic cought: %v, StackTrace: %s, %w", r, debug.Stack(), ErrPanic) + err := fmt.Errorf("panic cought: %v, stackTrace: %s, %w", r, debug.Stack(), ErrPanic) record.finish(StateFailed, nil, err) } }() diff --git a/wait_all.go b/wait_all.go index 5554994..dd9b970 100644 --- a/wait_all.go +++ b/wait_all.go @@ -27,11 +27,11 @@ func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...Waitable) er options = &WaitAllOptions{} } + // tried to close channel before exit this func, + // but it's complicated with routines, and we don't want to delay the return. + // per https://stackoverflow.com/questions/8593645/is-it-ok-to-leave-a-channel-open, its ok to leave channel open, eventually it will be garbage collected. + // this assumes the tasks eventually finish, otherwise we will have a routine leak. errorCh := make(chan error, tasksCount) - // when failFast enabled, we return on first error we see, while other task may still post error in this channel. - if !options.FailFast { - defer close(errorCh) - } for _, tsk := range tasks { go waitOne(ctx, tsk, errorCh) diff --git a/wait_all_test.go b/wait_all_test.go index 9740e96..1231cb8 100644 --- a/wait_all_test.go +++ b/wait_all_test.go @@ -13,8 +13,7 @@ import ( func TestWaitAll(t *testing.T) { t.Parallel() - ctx, cancelFunc := newTestContextWithTimeout(t, 2*time.Second) - defer cancelFunc() + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 2*time.Second) start := time.Now() countingTsk1 := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond)) @@ -26,13 +25,15 @@ func TestWaitAll(t *testing.T) { err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk) elapsed := time.Since(start) assert.NoError(t, err) + cancelTaskExecution() + // should only finish after longest task. assert.True(t, elapsed > 10*40*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) } func TestWaitAllFailFastCase(t *testing.T) { t.Parallel() - ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second) + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 3*time.Second) start := time.Now() countingTsk := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond)) @@ -44,9 +45,10 @@ func TestWaitAllFailFastCase(t *testing.T) { err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk, errorTsk, panicTsk, completedTsk) countingTskState := countingTsk.State() panicTskState := countingTsk.State() + errTskState := errorTsk.State() elapsed := time.Since(start) - cancelFunc() // all assertion variable captured, cancel counting task + cancelTaskExecution() // all assertion variable captured, cancel counting task assert.Error(t, err) assert.Equal(t, "expected error", err.Error()) @@ -56,6 +58,7 @@ func TestWaitAllFailFastCase(t *testing.T) { // since we pass FailFast, countingTsk and panicTsk should be still running assert.Equal(t, asynctask.StateRunning, countingTskState) assert.Equal(t, asynctask.StateRunning, panicTskState) + assert.Equal(t, asynctask.StateFailed, errTskState, "error task should the one failed the waitAll.") // counting task do testing.Logf in another go routine // while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 @@ -65,8 +68,7 @@ func TestWaitAllFailFastCase(t *testing.T) { func TestWaitAllErrorCase(t *testing.T) { t.Parallel() - ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second) - defer cancelFunc() + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 3*time.Second) start := time.Now() countingTsk := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond)) @@ -75,20 +77,24 @@ func TestWaitAllErrorCase(t *testing.T) { result := "something" completedTsk := asynctask.NewCompletedTask(&result) - err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: false}, countingTsk, errorTsk, panicTsk, completedTsk) + err := asynctask.WaitAll(ctx, nil, countingTsk, errorTsk, panicTsk, completedTsk) countingTskState := countingTsk.State() panicTskState := panicTsk.State() + errTskState := errorTsk.State() + completedTskState := completedTsk.State() elapsed := time.Since(start) - cancelFunc() // all assertion variable captured, cancel counting task + cancelTaskExecution() // all assertion variable captured, cancel counting task assert.Error(t, err) - assert.Equal(t, "expected error", err.Error()) + assert.Equal(t, "expected error", err.Error(), "expecting first error") // should only finish after longest task. assert.True(t, elapsed > 10*40*time.Millisecond, fmt.Sprintf("actually elapsed: %v", elapsed)) assert.Equal(t, asynctask.StateCompleted, countingTskState, "countingTask should finished") + assert.Equal(t, asynctask.StateFailed, errTskState, "error task should failed") assert.Equal(t, asynctask.StateFailed, panicTskState, "panic task should failed") + assert.Equal(t, asynctask.StateCompleted, completedTskState, "completed task should finished") // counting task do testing.Logf in another go routine // while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 @@ -96,10 +102,9 @@ func TestWaitAllErrorCase(t *testing.T) { time.Sleep(1 * time.Millisecond) } -func TestWaitAllCanceled(t *testing.T) { +func TestWaitAllFailFastCancelingWait(t *testing.T) { t.Parallel() - ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second) - defer cancelFunc() + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 3*time.Second) start := time.Now() countingTsk1 := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond)) @@ -108,18 +113,25 @@ func TestWaitAllCanceled(t *testing.T) { result := "something" completedTsk := asynctask.NewCompletedTask(&result) - waitCtx, cancelFunc1 := context.WithTimeout(ctx, 5*time.Millisecond) - defer cancelFunc1() + waitCtx, cancelWait := context.WithTimeout(ctx, 5*time.Millisecond) + defer cancelWait() - elapsed := time.Since(start) err := asynctask.WaitAll(waitCtx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk) - - cancelFunc() // all assertion variable captured, cancel counting task + elapsed := time.Since(start) + countingTsk1State := countingTsk1.State() + countingTsk2State := countingTsk2.State() + countingTsk3State := countingTsk3.State() + completedTskState := completedTsk.State() + cancelTaskExecution() // all assertion variable captured, cancel task execution assert.Error(t, err) assert.True(t, errors.Is(err, context.DeadlineExceeded)) // should return before first task assert.True(t, elapsed < 10*2*time.Millisecond) + assert.Equal(t, countingTsk1State, asynctask.StateRunning) + assert.Equal(t, countingTsk2State, asynctask.StateRunning) + assert.Equal(t, countingTsk3State, asynctask.StateRunning) + assert.Equal(t, completedTskState, asynctask.StateCompleted) // counting task do testing.Logf in another go routine // while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 @@ -127,6 +139,42 @@ func TestWaitAllCanceled(t *testing.T) { time.Sleep(1 * time.Millisecond) } +func TestWaitAllCancelingWait(t *testing.T) { + t.Parallel() + + ctx, cancelTaskExecution := newTestContextWithTimeout(t, 4*time.Millisecond) + + start := time.Now() + rcCtx, rcCancel := context.WithCancel(context.Background()) + uncontrollableTask := asynctask.Start(ctx, getUncontrollableTask(rcCtx, t)) + countingTsk1 := asynctask.Start(ctx, getCountingTask(10, "countingPer40ms", 40*time.Millisecond)) + countingTsk2 := asynctask.Start(ctx, getCountingTask(10, "countingPer20ms", 20*time.Millisecond)) + countingTsk3 := asynctask.Start(ctx, getCountingTask(10, "countingPer2ms", 2*time.Millisecond)) + result := "something" + completedTsk := asynctask.NewCompletedTask(&result) + + waitCtx, cancelWait := context.WithTimeout(ctx, 5*time.Millisecond) + defer cancelWait() + + err := asynctask.WaitAll(waitCtx, nil, countingTsk1, countingTsk2, countingTsk3, completedTsk, uncontrollableTask) + elapsed := time.Since(start) + t.Logf("WaitAll finished, elapsed: %v", elapsed) + cancelTaskExecution() // all assertion variable captured, cancel counting task + + assert.Error(t, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) + // should return before first task + assert.True(t, elapsed < 10*2*time.Millisecond) + + // cancel the remote control context to stop the uncontrollable task, or goleak.VerifyNone will fail. + rcCancel() + + // counting task do testing.Logf in another go routine + // while testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 + // wait minor time for the go routine to finish. + time.Sleep(50 * time.Millisecond) +} + func TestWaitAllWithNoTasks(t *testing.T) { t.Parallel() ctx, cancelFunc := newTestContextWithTimeout(t, 1*time.Millisecond) @@ -135,3 +183,20 @@ func TestWaitAllWithNoTasks(t *testing.T) { err := asynctask.WaitAll(ctx, nil) assert.NoError(t, err) } + +// getUncontrollableTask return a task that is not honor context, it only hornor the remoteControl context. +func getUncontrollableTask(rcCtx context.Context, t *testing.T) asynctask.AsyncFunc[int] { + return func(ctx context.Context) (*int, error) { + for { + select { + case <-time.After(1 * time.Millisecond): + if err := ctx.Err(); err != nil { + t.Logf("[UncontrollableTask]: context %s, but not honoring it.", err) + } + case <-rcCtx.Done(): + t.Logf("[UncontrollableTask]: cancelled by remote control") + return nil, rcCtx.Err() + } + } + } +}