From 4621b563cf6bbb386c5bace7670accec5baba04a Mon Sep 17 00:00:00 2001 From: Haitao Chen Date: Thu, 27 Oct 2022 15:55:38 -0700 Subject: [PATCH] after both, actionToFunc (#21) * try afterBoth * update after_both * ActionToFunc * more tweaks * remove afterAll, ready to merge to main * code coverage * update comments --- after_both.go | 33 +++++++++++++++++++++ after_both_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++ continue_with.go | 9 ++++++ continue_with_test.go | 16 ++++++++++ task.go | 9 ++++++ task_test.go | 14 +++++++++ 6 files changed, 150 insertions(+) create mode 100644 after_both.go create mode 100644 after_both_test.go diff --git a/after_both.go b/after_both.go new file mode 100644 index 0000000..1aa33da --- /dev/null +++ b/after_both.go @@ -0,0 +1,33 @@ +package asynctask + +import "context" + +// AfterBothFunc is a function that has 2 input. +type AfterBothFunc[T, S, R any] func(context.Context, *T, *S) (*R, error) + +// AfterBoth runs the function after both 2 input task finished, and will be fed with result from 2 input task. +// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running. +func AfterBoth[T, S, R any](ctx context.Context, tskT *Task[T], tskS *Task[S], next AfterBothFunc[T, S, R]) *Task[R] { + return Start(ctx, func(fCtx context.Context) (*R, error) { + t, err := tskT.Result(fCtx) + if err != nil { + return nil, err + } + + s, err := tskS.Result(fCtx) + if err != nil { + return nil, err + } + + return next(fCtx, t, s) + }) +} + +// AfterBothActionToFunc convert a Action to Func (C# term), to satisfy the AfterBothFunc interface. +// Action is function that runs without return anything +// Func is function that runs and return something +func AfterBothActionToFunc[T, S any](action func(context.Context, *T, *S) error) func(context.Context, *T, *S) (*interface{}, error) { + return func(ctx context.Context, t *T, s *S) (*interface{}, error) { + return nil, action(ctx, t, s) + } +} diff --git a/after_both_test.go b/after_both_test.go new file mode 100644 index 0000000..ca0663c --- /dev/null +++ b/after_both_test.go @@ -0,0 +1,69 @@ +package asynctask_test + +import ( + "context" + "testing" + "time" + + "github.com/Azure/go-asynctask" + "github.com/stretchr/testify/assert" +) + +func summarize2CountingTask(ctx context.Context, result1, result2 *int) (*int, error) { + t := ctx.Value(testContextKey).(*testing.T) + t.Logf("result1: %d", result1) + t.Logf("result2: %d", result2) + sum := *result1 + *result2 + t.Logf("sum: %d", sum) + return &sum, nil +} + +func TestAfterBoth(t *testing.T) { + t.Parallel() + ctx := newTestContext(t) + t1 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P1", 20*time.Millisecond)) + t2 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P2", 20*time.Millisecond)) + t3 := asynctask.AfterBoth(ctx, t1, t2, summarize2CountingTask) + + sum, err := t3.Result(ctx) + assert.NoError(t, err) + assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error") + assert.Equal(t, *sum, 18, "Sum should be 18") +} + +func TestAfterBothFailureCase(t *testing.T) { + t.Parallel() + ctx := newTestContext(t) + errTask := asynctask.Start(ctx, getErrorTask("devide by 0", 10*time.Millisecond)) + countingTask := asynctask.Start(ctx, getCountingTask(10, "afterboth.P1", 20*time.Millisecond)) + + task1Err := asynctask.AfterBoth(ctx, errTask, countingTask, summarize2CountingTask) + _, err := task1Err.Result(ctx) + assert.Error(t, err) + + task2Err := asynctask.AfterBoth(ctx, errTask, countingTask, summarize2CountingTask) + _, err = task2Err.Result(ctx) + assert.Error(t, err) + + task3NoErr := asynctask.AfterBoth(ctx, countingTask, countingTask, summarize2CountingTask) + _, err = task3NoErr.Result(ctx) + assert.NoError(t, err) +} + +func TestAfterBothActionToFunc(t *testing.T) { + t.Parallel() + ctx := newTestContext(t) + + countingTask1 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P1", 20*time.Millisecond)) + countingTask2 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P2", 20*time.Millisecond)) + t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 *int) error { + t := ctx.Value(testContextKey).(*testing.T) + t.Logf("result1: %d", result1) + t.Logf("result2: %d", result2) + sum := *result1 + *result2 + t.Logf("sum: %d", sum) + return nil + })) + _, err := t2.Result(ctx) + assert.NoError(t, err) +} diff --git a/continue_with.go b/continue_with.go index e6df119..0e6ab9d 100644 --- a/continue_with.go +++ b/continue_with.go @@ -14,3 +14,12 @@ func ContinueWith[T any, S any](ctx context.Context, tsk *Task[T], next Continue return next(fCtx, result) }) } + +// ContinueActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface. +// Action is function that runs without return anything +// Func is function that runs and return something +func ContinueActionToFunc[T any](action func(context.Context, *T) error) func(context.Context, *T) (*interface{}, error) { + return func(ctx context.Context, t *T) (*interface{}, error) { + return nil, action(ctx, t) + } +} diff --git a/continue_with_test.go b/continue_with_test.go index 380c167..1534be0 100644 --- a/continue_with_test.go +++ b/continue_with_test.go @@ -2,6 +2,7 @@ package asynctask_test import ( "context" + "fmt" "testing" "time" @@ -75,3 +76,18 @@ func TestContinueWithFailureCase(t *testing.T) { assert.Equal(t, asynctask.StateFailed, t3.State(), "Task3 should fail since preceeding task failed") assert.Equal(t, "devide by 0", err.Error()) } + +func TestContinueActionToFunc(t *testing.T) { + t.Parallel() + ctx := newTestContext(t) + + t1 := asynctask.Start(ctx, func(ctx context.Context) (*int, error) { i := 0; return &i, nil }) + t2 := asynctask.ContinueWith(ctx, t1, asynctask.ContinueActionToFunc(func(ctx context.Context, i *int) error { + if *i != 0 { + return fmt.Errorf("input should be 0, but got %d", i) + } + return nil + })) + _, err := t2.Result(ctx) + assert.NoError(t, err) +} diff --git a/task.go b/task.go index de48898..699c256 100644 --- a/task.go +++ b/task.go @@ -11,6 +11,15 @@ import ( // AsyncFunc is a function interface this asyncTask accepts. type AsyncFunc[T any] func(context.Context) (*T, error) +// ActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface. +// Action is function that runs without return anything +// Func is function that runs and return something +func ActionToFunc(action func(context.Context) error) func(context.Context) (*interface{}, error) { + return func(ctx context.Context) (*interface{}, error) { + return nil, action(ctx) + } +} + // Task is a handle to the running function. // which you can use to wait, cancel, get the result. type Task[T any] struct { diff --git a/task_test.go b/task_test.go index d21613c..ba6f7c5 100644 --- a/task_test.go +++ b/task_test.go @@ -150,6 +150,20 @@ func TestCompletedGenericTask(t *testing.T) { assert.Equal(t, *resultGet, result) } +func TestActionToFunc(t *testing.T) { + t.Parallel() + + action := func(ctx context.Context) error { + return nil + } + + ctx := context.Background() + task := asynctask.Start(ctx, asynctask.ActionToFunc(action)) + result, err := task.Result(ctx) + assert.NoError(t, err) + assert.Nil(t, result) +} + func TestCrazyCaseGeneric(t *testing.T) { t.Parallel() ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)