Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove restriction on return pointer value of TypeParameter #33

Merged
merged 5 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@

Simple mimik of async/await for those come from C# world, so you don't need to dealing with waitGroup/channel in golang.

also the result is strongTyped with go generics, no type assertion is needed.

few chaining method provided:
- ContinueWith: send task1's output to task2 as input, return reference to task2.
- AfterBoth : send output of taskA, taskB to taskC as input, return reference to taskC.
- WaitAll: all of the task have to finish to end the wait (with an option to fail early if any task failed)
- WaitAny: any of the task finish would end the wait

```golang
// start task
task := asynctask.Start(ctx, countingTask)
Expand Down
20 changes: 11 additions & 9 deletions after_both.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,33 @@ package asynctask
import "context"

// AfterBothFunc is a function that has 2 input.
type AfterBothFunc[T, S, R any] func(context.Context, *T, *S) (*R, error)
type AfterBothFunc[T, S, R any] func(context.Context, T, S) (R, error)

// AfterBoth runs the function after both 2 input task finished, and will be fed with result from 2 input task.
// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running.
//
// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running.
func AfterBoth[T, S, R any](ctx context.Context, tskT *Task[T], tskS *Task[S], next AfterBothFunc[T, S, R]) *Task[R] {
return Start(ctx, func(fCtx context.Context) (*R, error) {
return Start(ctx, func(fCtx context.Context) (R, error) {
t, err := tskT.Result(fCtx)
if err != nil {
return nil, err
return *new(R), err
}

s, err := tskS.Result(fCtx)
if err != nil {
return nil, err
return *new(R), err
}

return next(fCtx, t, s)
})
}

// AfterBothActionToFunc convert a Action to Func (C# term), to satisfy the AfterBothFunc interface.
// Action is function that runs without return anything
// Func is function that runs and return something
func AfterBothActionToFunc[T, S any](action func(context.Context, *T, *S) error) func(context.Context, *T, *S) (*interface{}, error) {
return func(ctx context.Context, t *T, s *S) (*interface{}, error) {
//
// Action is function that runs without return anything
// Func is function that runs and return something
func AfterBothActionToFunc[T, S any](action func(context.Context, T, S) error) func(context.Context, T, S) (interface{}, error) {
return func(ctx context.Context, t T, s S) (interface{}, error) {
return nil, action(ctx, t, s)
}
}
12 changes: 6 additions & 6 deletions after_both_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
"github.com/stretchr/testify/assert"
)

func summarize2CountingTask(ctx context.Context, result1, result2 *int) (*int, error) {
func summarize2CountingTask(ctx context.Context, result1, result2 int) (int, error) {
t := ctx.Value(testContextKey).(*testing.T)
t.Logf("result1: %d", result1)
t.Logf("result2: %d", result2)
sum := *result1 + *result2
sum := result1 + result2
t.Logf("sum: %d", sum)
return &sum, nil
return sum, nil
}

func TestAfterBoth(t *testing.T) {
Expand All @@ -28,7 +28,7 @@ func TestAfterBoth(t *testing.T) {
sum, err := t3.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
assert.Equal(t, *sum, 18, "Sum should be 18")
assert.Equal(t, sum, 18, "Sum should be 18")
}

func TestAfterBothFailureCase(t *testing.T) {
Expand Down Expand Up @@ -56,11 +56,11 @@ func TestAfterBothActionToFunc(t *testing.T) {

countingTask1 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P1", 20*time.Millisecond))
countingTask2 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P2", 20*time.Millisecond))
t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 *int) error {
t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 int) error {
t := ctx.Value(testContextKey).(*testing.T)
t.Logf("result1: %d", result1)
t.Logf("result2: %d", result2)
sum := *result1 + *result2
sum := result1 + result2
t.Logf("sum: %d", sum)
return nil
}))
Expand Down
15 changes: 8 additions & 7 deletions continue_with.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,24 @@ package asynctask
import "context"

// ContinueFunc is a function that can be connected to previous task with ContinueWith
type ContinueFunc[T any, S any] func(context.Context, *T) (*S, error)
type ContinueFunc[T any, S any] func(context.Context, T) (S, error)

func ContinueWith[T any, S any](ctx context.Context, tsk *Task[T], next ContinueFunc[T, S]) *Task[S] {
return Start(ctx, func(fCtx context.Context) (*S, error) {
return Start(ctx, func(fCtx context.Context) (S, error) {
result, err := tsk.Result(fCtx)
if err != nil {
return nil, err
return *new(S), err
}
return next(fCtx, result)
})
}

// ContinueActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface.
// Action is function that runs without return anything
// Func is function that runs and return something
func ContinueActionToFunc[T any](action func(context.Context, *T) error) func(context.Context, *T) (*interface{}, error) {
return func(ctx context.Context, t *T) (*interface{}, error) {
//
// Action is function that runs without return anything
// Func is function that runs and return something
func ContinueActionToFunc[T any](action func(context.Context, T) error) func(context.Context, T) (interface{}, error) {
return func(ctx context.Context, t T) (interface{}, error) {
return nil, action(ctx, t)
}
}
26 changes: 13 additions & 13 deletions continue_with_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
return func(ctx context.Context) (int, error) {
t := ctx.Value(testContextKey).(*testing.T)

result := countFrom
Expand All @@ -22,47 +22,47 @@ func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duratio
result++
case <-ctx.Done():
t.Log("work canceled")
return &result, nil
return result, nil
}
}
return &result, nil
return result, nil
}
}

func TestContinueWith(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getAdvancedCountingTask(0, 10, 20*time.Millisecond))
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
})
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
})

result, err := t2.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t2.State(), "Task should complete with no error")
assert.Equal(t, *result, 20)
assert.Equal(t, result, 20)

result, err = t3.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
assert.Equal(t, *result, 22)
assert.Equal(t, result, 22)
}

func TestContinueWithFailureCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getErrorTask("devide by 0", 10*time.Millisecond))
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
})
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
})

Expand Down
10 changes: 5 additions & 5 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ import (
)

func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc[string] {
return func(ctx context.Context) (*string, error) {
return func(ctx context.Context) (string, error) {
time.Sleep(sleepDuration)
panic("yo")
}
}

func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
return func(ctx context.Context) (int, error) {
time.Sleep(sleepDuration)
return nil, errors.New(errorString)
return 0, errors.New(errorString)
}
}

Expand All @@ -37,12 +37,12 @@ func TestTimeoutCase(t *testing.T) {
// I can continue wait with longer time
rawResult, err := tsk.WaitWithTimeout(ctx, 2*time.Second)
assert.NoError(t, err)
assert.Equal(t, 9, *rawResult)
assert.Equal(t, 9, rawResult)

// any following Wait should complete immediately
rawResult, err = tsk.WaitWithTimeout(ctx, 2*time.Nanosecond)
assert.NoError(t, err)
assert.Equal(t, 9, *rawResult)
assert.Equal(t, 9, rawResult)
}

func TestPanicCase(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/Azure/go-asynctask

go 1.19
go 1.20

require github.com/stretchr/testify v1.8.4

Expand Down
41 changes: 20 additions & 21 deletions task.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
)

// AsyncFunc is a function interface this asyncTask accepts.
type AsyncFunc[T any] func(context.Context) (*T, error)
type AsyncFunc[T any] func(context.Context) (T, error)

// ActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface.
// - Action is function that runs without return anything
// - Func is function that runs and return something
func ActionToFunc(action func(context.Context) error) func(context.Context) (*interface{}, error) {
return func(ctx context.Context) (*interface{}, error) {
func ActionToFunc(action func(context.Context) error) func(context.Context) (interface{}, error) {
return func(ctx context.Context) (interface{}, error) {
return nil, action(ctx)
}
}
Expand All @@ -24,25 +24,25 @@ func ActionToFunc(action func(context.Context) error) func(context.Context) (*in
// which you can use to wait, cancel, get the result.
type Task[T any] struct {
state State
result *T
result T
err error
cancelFunc context.CancelFunc
waitGroup *sync.WaitGroup
mutex *sync.Mutex
mutex *sync.RWMutex
}

// State return state of the task.
func (t *Task[T]) State() State {
t.mutex.Lock()
defer t.mutex.Unlock()
t.mutex.RLock()
defer t.mutex.RUnlock()
return t.state
}

// Cancel the task by cancel the context.
// !! this rely on the task function to check context cancellation and proper context handling.
func (t *Task[T]) Cancel() bool {
if !t.finished() {
t.finish(StateCanceled, nil, ErrCanceled)
t.finish(StateCanceled, *new(T), ErrCanceled)
return true
}

Expand Down Expand Up @@ -74,7 +74,7 @@ func (t *Task[T]) Wait(ctx context.Context) error {

// WaitWithTimeout block current thread/routine until task finished or failed, or exceed the duration specified.
// timeout only stop waiting, taks will remain running.
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*T, error) {
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (T, error) {
// return immediately if task already in terminal state.
if t.finished() {
return t.result, t.err
Expand All @@ -86,11 +86,10 @@ func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*
return t.Result(ctx)
}

func (t *Task[T]) Result(ctx context.Context) (*T, error) {
func (t *Task[T]) Result(ctx context.Context) (T, error) {
err := t.Wait(ctx)
if err != nil {
var result T
return &result, err
return *new(T), err
}

return t.result, t.err
Expand All @@ -102,11 +101,11 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
ctx, cancel := context.WithCancel(ctx)
wg := &sync.WaitGroup{}
wg.Add(1)
mutex := &sync.Mutex{}
mutex := &sync.RWMutex{}

record := &Task[T]{
state: StateRunning,
result: nil,
result: *new(T),
cancelFunc: cancel,
waitGroup: wg,
mutex: mutex,
Expand All @@ -118,24 +117,24 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
}

// NewCompletedTask returns a Completed task, with result=nil, error=nil
func NewCompletedTask[T any](value *T) *Task[T] {
func NewCompletedTask[T any](value T) *Task[T] {
return &Task[T]{
state: StateCompleted,
result: value,
err: nil,
// nil cancelFunc and waitGroup should be protected with IsTerminalState()
cancelFunc: nil,
waitGroup: nil,
mutex: &sync.Mutex{},
mutex: &sync.RWMutex{},
}
}

func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (*T, error)) {
func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (T, error)) {
defer record.waitGroup.Done()
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic cought: %v, stackTrace: %s, %w", r, debug.Stack(), ErrPanic)
record.finish(StateFailed, nil, err)
record.finish(StateFailed, *new(T), err)
}
}()

Expand All @@ -150,7 +149,7 @@ func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task fu
record.finish(StateFailed, result, err)
}

func (t *Task[T]) finish(state State, result *T, err error) {
func (t *Task[T]) finish(state State, result T, err error) {
// only update state and result if not yet canceled
t.mutex.Lock()
defer t.mutex.Unlock()
Expand All @@ -163,7 +162,7 @@ func (t *Task[T]) finish(state State, result *T, err error) {
}

func (t *Task[T]) finished() bool {
t.mutex.Lock()
defer t.mutex.Unlock()
t.mutex.RLock()
defer t.mutex.RUnlock()
return t.state.IsTerminalState()
}
Loading
Loading