diff --git a/README.md b/README.md index cba7094..efe64d5 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,14 @@ Simple mimik of async/await for those come from C# world, so you don't need to dealing with waitGroup/channel in golang. +also the result is strongTyped with go generics, no type assertion is needed. + +few chaining method provided: +- ContinueWith: send task1's output to task2 as input, return reference to task2. +- AfterBoth : send output of taskA, taskB to taskC as input, return reference to taskC. +- WaitAll: all of the task have to finish to end the wait (with an option to fail early if any task failed) +- WaitAny: any of the task finish would end the wait + ```golang // start task task := asynctask.Start(ctx, countingTask) diff --git a/after_both.go b/after_both.go index 1aa33da..376ea3b 100644 --- a/after_both.go +++ b/after_both.go @@ -3,20 +3,21 @@ package asynctask import "context" // AfterBothFunc is a function that has 2 input. -type AfterBothFunc[T, S, R any] func(context.Context, *T, *S) (*R, error) +type AfterBothFunc[T, S, R any] func(context.Context, T, S) (R, error) // AfterBoth runs the function after both 2 input task finished, and will be fed with result from 2 input task. -// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running. +// +// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running. func AfterBoth[T, S, R any](ctx context.Context, tskT *Task[T], tskS *Task[S], next AfterBothFunc[T, S, R]) *Task[R] { - return Start(ctx, func(fCtx context.Context) (*R, error) { + return Start(ctx, func(fCtx context.Context) (R, error) { t, err := tskT.Result(fCtx) if err != nil { - return nil, err + return *new(R), err } s, err := tskS.Result(fCtx) if err != nil { - return nil, err + return *new(R), err } return next(fCtx, t, s) @@ -24,10 +25,11 @@ func AfterBoth[T, S, R any](ctx context.Context, tskT *Task[T], tskS *Task[S], n } // AfterBothActionToFunc convert a Action to Func (C# term), to satisfy the AfterBothFunc interface. -// Action is function that runs without return anything -// Func is function that runs and return something -func AfterBothActionToFunc[T, S any](action func(context.Context, *T, *S) error) func(context.Context, *T, *S) (*interface{}, error) { - return func(ctx context.Context, t *T, s *S) (*interface{}, error) { +// +// Action is function that runs without return anything +// Func is function that runs and return something +func AfterBothActionToFunc[T, S any](action func(context.Context, T, S) error) func(context.Context, T, S) (interface{}, error) { + return func(ctx context.Context, t T, s S) (interface{}, error) { return nil, action(ctx, t, s) } } diff --git a/after_both_test.go b/after_both_test.go index ca0663c..6adf895 100644 --- a/after_both_test.go +++ b/after_both_test.go @@ -9,13 +9,13 @@ import ( "github.com/stretchr/testify/assert" ) -func summarize2CountingTask(ctx context.Context, result1, result2 *int) (*int, error) { +func summarize2CountingTask(ctx context.Context, result1, result2 int) (int, error) { t := ctx.Value(testContextKey).(*testing.T) t.Logf("result1: %d", result1) t.Logf("result2: %d", result2) - sum := *result1 + *result2 + sum := result1 + result2 t.Logf("sum: %d", sum) - return &sum, nil + return sum, nil } func TestAfterBoth(t *testing.T) { @@ -28,7 +28,7 @@ func TestAfterBoth(t *testing.T) { sum, err := t3.Result(ctx) assert.NoError(t, err) assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error") - assert.Equal(t, *sum, 18, "Sum should be 18") + assert.Equal(t, sum, 18, "Sum should be 18") } func TestAfterBothFailureCase(t *testing.T) { @@ -56,11 +56,11 @@ func TestAfterBothActionToFunc(t *testing.T) { countingTask1 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P1", 20*time.Millisecond)) countingTask2 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P2", 20*time.Millisecond)) - t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 *int) error { + t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 int) error { t := ctx.Value(testContextKey).(*testing.T) t.Logf("result1: %d", result1) t.Logf("result2: %d", result2) - sum := *result1 + *result2 + sum := result1 + result2 t.Logf("sum: %d", sum) return nil })) diff --git a/continue_with.go b/continue_with.go index 0e6ab9d..5594848 100644 --- a/continue_with.go +++ b/continue_with.go @@ -3,23 +3,24 @@ package asynctask import "context" // ContinueFunc is a function that can be connected to previous task with ContinueWith -type ContinueFunc[T any, S any] func(context.Context, *T) (*S, error) +type ContinueFunc[T any, S any] func(context.Context, T) (S, error) func ContinueWith[T any, S any](ctx context.Context, tsk *Task[T], next ContinueFunc[T, S]) *Task[S] { - return Start(ctx, func(fCtx context.Context) (*S, error) { + return Start(ctx, func(fCtx context.Context) (S, error) { result, err := tsk.Result(fCtx) if err != nil { - return nil, err + return *new(S), err } return next(fCtx, result) }) } // ContinueActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface. -// Action is function that runs without return anything -// Func is function that runs and return something -func ContinueActionToFunc[T any](action func(context.Context, *T) error) func(context.Context, *T) (*interface{}, error) { - return func(ctx context.Context, t *T) (*interface{}, error) { +// +// Action is function that runs without return anything +// Func is function that runs and return something +func ContinueActionToFunc[T any](action func(context.Context, T) error) func(context.Context, T) (interface{}, error) { + return func(ctx context.Context, t T) (interface{}, error) { return nil, action(ctx, t) } } diff --git a/continue_with_test.go b/continue_with_test.go index 1534be0..603fe28 100644 --- a/continue_with_test.go +++ b/continue_with_test.go @@ -11,7 +11,7 @@ import ( ) func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duration) asynctask.AsyncFunc[int] { - return func(ctx context.Context) (*int, error) { + return func(ctx context.Context) (int, error) { t := ctx.Value(testContextKey).(*testing.T) result := countFrom @@ -22,10 +22,10 @@ func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duratio result++ case <-ctx.Done(): t.Log("work canceled") - return &result, nil + return result, nil } } - return &result, nil + return result, nil } } @@ -33,36 +33,36 @@ func TestContinueWith(t *testing.T) { t.Parallel() ctx := newTestContext(t) t1 := asynctask.Start(ctx, getAdvancedCountingTask(0, 10, 20*time.Millisecond)) - t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) { - fromPrevTsk := *input + t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) { + fromPrevTsk := input return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx) }) - t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) { - fromPrevTsk := *input + t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) { + fromPrevTsk := input return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx) }) result, err := t2.Result(ctx) assert.NoError(t, err) assert.Equal(t, asynctask.StateCompleted, t2.State(), "Task should complete with no error") - assert.Equal(t, *result, 20) + assert.Equal(t, result, 20) result, err = t3.Result(ctx) assert.NoError(t, err) assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error") - assert.Equal(t, *result, 22) + assert.Equal(t, result, 22) } func TestContinueWithFailureCase(t *testing.T) { t.Parallel() ctx := newTestContext(t) t1 := asynctask.Start(ctx, getErrorTask("devide by 0", 10*time.Millisecond)) - t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) { - fromPrevTsk := *input + t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) { + fromPrevTsk := input return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx) }) - t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) { - fromPrevTsk := *input + t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) { + fromPrevTsk := input return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx) }) diff --git a/error_test.go b/error_test.go index b4f43b2..b78267d 100644 --- a/error_test.go +++ b/error_test.go @@ -11,16 +11,16 @@ import ( ) func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc[string] { - return func(ctx context.Context) (*string, error) { + return func(ctx context.Context) (string, error) { time.Sleep(sleepDuration) panic("yo") } } func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc[int] { - return func(ctx context.Context) (*int, error) { + return func(ctx context.Context) (int, error) { time.Sleep(sleepDuration) - return nil, errors.New(errorString) + return 0, errors.New(errorString) } } @@ -37,12 +37,12 @@ func TestTimeoutCase(t *testing.T) { // I can continue wait with longer time rawResult, err := tsk.WaitWithTimeout(ctx, 2*time.Second) assert.NoError(t, err) - assert.Equal(t, 9, *rawResult) + assert.Equal(t, 9, rawResult) // any following Wait should complete immediately rawResult, err = tsk.WaitWithTimeout(ctx, 2*time.Nanosecond) assert.NoError(t, err) - assert.Equal(t, 9, *rawResult) + assert.Equal(t, 9, rawResult) } func TestPanicCase(t *testing.T) { diff --git a/go.mod b/go.mod index efb7121..5c54d80 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/Azure/go-asynctask -go 1.19 +go 1.20 require github.com/stretchr/testify v1.8.4 diff --git a/task.go b/task.go index 1c31861..c30b38e 100644 --- a/task.go +++ b/task.go @@ -9,13 +9,13 @@ import ( ) // AsyncFunc is a function interface this asyncTask accepts. -type AsyncFunc[T any] func(context.Context) (*T, error) +type AsyncFunc[T any] func(context.Context) (T, error) // ActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface. // - Action is function that runs without return anything // - Func is function that runs and return something -func ActionToFunc(action func(context.Context) error) func(context.Context) (*interface{}, error) { - return func(ctx context.Context) (*interface{}, error) { +func ActionToFunc(action func(context.Context) error) func(context.Context) (interface{}, error) { + return func(ctx context.Context) (interface{}, error) { return nil, action(ctx) } } @@ -24,17 +24,17 @@ func ActionToFunc(action func(context.Context) error) func(context.Context) (*in // which you can use to wait, cancel, get the result. type Task[T any] struct { state State - result *T + result T err error cancelFunc context.CancelFunc waitGroup *sync.WaitGroup - mutex *sync.Mutex + mutex *sync.RWMutex } // State return state of the task. func (t *Task[T]) State() State { - t.mutex.Lock() - defer t.mutex.Unlock() + t.mutex.RLock() + defer t.mutex.RUnlock() return t.state } @@ -42,7 +42,7 @@ func (t *Task[T]) State() State { // !! this rely on the task function to check context cancellation and proper context handling. func (t *Task[T]) Cancel() bool { if !t.finished() { - t.finish(StateCanceled, nil, ErrCanceled) + t.finish(StateCanceled, *new(T), ErrCanceled) return true } @@ -74,7 +74,7 @@ func (t *Task[T]) Wait(ctx context.Context) error { // WaitWithTimeout block current thread/routine until task finished or failed, or exceed the duration specified. // timeout only stop waiting, taks will remain running. -func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*T, error) { +func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (T, error) { // return immediately if task already in terminal state. if t.finished() { return t.result, t.err @@ -86,11 +86,10 @@ func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (* return t.Result(ctx) } -func (t *Task[T]) Result(ctx context.Context) (*T, error) { +func (t *Task[T]) Result(ctx context.Context) (T, error) { err := t.Wait(ctx) if err != nil { - var result T - return &result, err + return *new(T), err } return t.result, t.err @@ -102,11 +101,11 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] { ctx, cancel := context.WithCancel(ctx) wg := &sync.WaitGroup{} wg.Add(1) - mutex := &sync.Mutex{} + mutex := &sync.RWMutex{} record := &Task[T]{ state: StateRunning, - result: nil, + result: *new(T), cancelFunc: cancel, waitGroup: wg, mutex: mutex, @@ -118,7 +117,7 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] { } // NewCompletedTask returns a Completed task, with result=nil, error=nil -func NewCompletedTask[T any](value *T) *Task[T] { +func NewCompletedTask[T any](value T) *Task[T] { return &Task[T]{ state: StateCompleted, result: value, @@ -126,16 +125,16 @@ func NewCompletedTask[T any](value *T) *Task[T] { // nil cancelFunc and waitGroup should be protected with IsTerminalState() cancelFunc: nil, waitGroup: nil, - mutex: &sync.Mutex{}, + mutex: &sync.RWMutex{}, } } -func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (*T, error)) { +func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (T, error)) { defer record.waitGroup.Done() defer func() { if r := recover(); r != nil { err := fmt.Errorf("panic cought: %v, stackTrace: %s, %w", r, debug.Stack(), ErrPanic) - record.finish(StateFailed, nil, err) + record.finish(StateFailed, *new(T), err) } }() @@ -150,7 +149,7 @@ func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task fu record.finish(StateFailed, result, err) } -func (t *Task[T]) finish(state State, result *T, err error) { +func (t *Task[T]) finish(state State, result T, err error) { // only update state and result if not yet canceled t.mutex.Lock() defer t.mutex.Unlock() @@ -163,7 +162,7 @@ func (t *Task[T]) finish(state State, result *T, err error) { } func (t *Task[T]) finished() bool { - t.mutex.Lock() - defer t.mutex.Unlock() + t.mutex.RLock() + defer t.mutex.RUnlock() return t.state.IsTerminalState() } diff --git a/task_test.go b/task_test.go index ba6f7c5..cd91a85 100644 --- a/task_test.go +++ b/task_test.go @@ -22,7 +22,7 @@ func newTestContextWithTimeout(t *testing.T, timeout time.Duration) (context.Con } func getCountingTask(countTo int, taskId string, sleepInterval time.Duration) asynctask.AsyncFunc[int] { - return func(ctx context.Context) (*int, error) { + return func(ctx context.Context) (int, error) { t := ctx.Value(testContextKey).(*testing.T) result := 0 @@ -35,10 +35,10 @@ func getCountingTask(countTo int, taskId string, sleepInterval time.Duration) as // testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343 // leave minor time buffer before exit test to finish this last logging at least. t.Logf("[%s]: work canceled", taskId) - return &result, nil + return result, nil } } - return &result, nil + return result, nil } } @@ -54,7 +54,7 @@ func TestEasyGenericCase(t *testing.T) { assert.NoError(t, err) assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now") assert.NotNil(t, rawResult) - assert.Equal(t, *rawResult, 9) + assert.Equal(t, rawResult, 9) // wait Again, start := time.Now() @@ -64,7 +64,7 @@ func TestEasyGenericCase(t *testing.T) { assert.NoError(t, err) assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now") assert.NotNil(t, rawResult) - assert.Equal(t, *rawResult, 9) + assert.Equal(t, rawResult, 9) // Result should be returned immediately assert.True(t, elapsed.Milliseconds() < 1, fmt.Sprintf("Second wait should have return immediately: %s", elapsed)) @@ -121,13 +121,13 @@ func TestConsistentResultAfterCancelGenericTask(t *testing.T) { rawResult, err := t2.Result(ctx) assert.NoError(t, err) assert.Equal(t, asynctask.StateCompleted, t2.State(), "t2 should complete") - assert.Equal(t, *rawResult, 9) + assert.Equal(t, rawResult, 9) // t1 should remain canceled and rawResult, err = t1.Result(ctx) assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error") assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state") - assert.Equal(t, *rawResult, 0) // default value for int + assert.Equal(t, rawResult, 0) // default value for int } func TestCompletedGenericTask(t *testing.T) { @@ -187,10 +187,10 @@ func TestCrazyCaseGeneric(t *testing.T) { if i%2 == 0 { assert.Equal(t, asynctask.ErrCanceled, err, fmt.Sprintf("task %s should be canceled, but it finished with %+v", fmt.Sprintf("CrazyTask%d", i), rawResult)) - assert.Equal(t, *rawResult, 0) + assert.Equal(t, rawResult, 0) } else { assert.NoError(t, err) - assert.Equal(t, *rawResult, 9) + assert.Equal(t, rawResult, 9) } } } diff --git a/wait_all_test.go b/wait_all_test.go index 1231cb8..bebeb39 100644 --- a/wait_all_test.go +++ b/wait_all_test.go @@ -186,7 +186,7 @@ func TestWaitAllWithNoTasks(t *testing.T) { // getUncontrollableTask return a task that is not honor context, it only hornor the remoteControl context. func getUncontrollableTask(rcCtx context.Context, t *testing.T) asynctask.AsyncFunc[int] { - return func(ctx context.Context) (*int, error) { + return func(ctx context.Context) (int, error) { for { select { case <-time.After(1 * time.Millisecond): @@ -195,7 +195,7 @@ func getUncontrollableTask(rcCtx context.Context, t *testing.T) asynctask.AsyncF } case <-rcCtx.Done(): t.Logf("[UncontrollableTask]: cancelled by remote control") - return nil, rcCtx.Err() + return 0, rcCtx.Err() } } }