Skip to content

Commit

Permalink
remove restriction on return pointer value of TypeParameter (#33)
Browse files Browse the repository at this point in the history
* remove pointer bind

* Change sync.Mutex to sync.RWMutex in Task struct

* Update go.mod to use go 1.20

* tweaks
  • Loading branch information
haitch committed Nov 13, 2023
1 parent 4c78f32 commit 442a02a
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 73 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@

Simple mimik of async/await for those come from C# world, so you don't need to dealing with waitGroup/channel in golang.

also the result is strongTyped with go generics, no type assertion is needed.

few chaining method provided:
- ContinueWith: send task1's output to task2 as input, return reference to task2.
- AfterBoth : send output of taskA, taskB to taskC as input, return reference to taskC.
- WaitAll: all of the task have to finish to end the wait (with an option to fail early if any task failed)
- WaitAny: any of the task finish would end the wait

```golang
// start task
task := asynctask.Start(ctx, countingTask)
Expand Down
20 changes: 11 additions & 9 deletions after_both.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,33 @@ package asynctask
import "context"

// AfterBothFunc is a function that has 2 input.
type AfterBothFunc[T, S, R any] func(context.Context, *T, *S) (*R, error)
type AfterBothFunc[T, S, R any] func(context.Context, T, S) (R, error)

// AfterBoth runs the function after both 2 input task finished, and will be fed with result from 2 input task.
// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running.
//
// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running.
func AfterBoth[T, S, R any](ctx context.Context, tskT *Task[T], tskS *Task[S], next AfterBothFunc[T, S, R]) *Task[R] {
return Start(ctx, func(fCtx context.Context) (*R, error) {
return Start(ctx, func(fCtx context.Context) (R, error) {
t, err := tskT.Result(fCtx)
if err != nil {
return nil, err
return *new(R), err
}

s, err := tskS.Result(fCtx)
if err != nil {
return nil, err
return *new(R), err
}

return next(fCtx, t, s)
})
}

// AfterBothActionToFunc convert a Action to Func (C# term), to satisfy the AfterBothFunc interface.
// Action is function that runs without return anything
// Func is function that runs and return something
func AfterBothActionToFunc[T, S any](action func(context.Context, *T, *S) error) func(context.Context, *T, *S) (*interface{}, error) {
return func(ctx context.Context, t *T, s *S) (*interface{}, error) {
//
// Action is function that runs without return anything
// Func is function that runs and return something
func AfterBothActionToFunc[T, S any](action func(context.Context, T, S) error) func(context.Context, T, S) (interface{}, error) {
return func(ctx context.Context, t T, s S) (interface{}, error) {
return nil, action(ctx, t, s)
}
}
12 changes: 6 additions & 6 deletions after_both_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
"github.com/stretchr/testify/assert"
)

func summarize2CountingTask(ctx context.Context, result1, result2 *int) (*int, error) {
func summarize2CountingTask(ctx context.Context, result1, result2 int) (int, error) {
t := ctx.Value(testContextKey).(*testing.T)
t.Logf("result1: %d", result1)
t.Logf("result2: %d", result2)
sum := *result1 + *result2
sum := result1 + result2
t.Logf("sum: %d", sum)
return &sum, nil
return sum, nil
}

func TestAfterBoth(t *testing.T) {
Expand All @@ -28,7 +28,7 @@ func TestAfterBoth(t *testing.T) {
sum, err := t3.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
assert.Equal(t, *sum, 18, "Sum should be 18")
assert.Equal(t, sum, 18, "Sum should be 18")
}

func TestAfterBothFailureCase(t *testing.T) {
Expand Down Expand Up @@ -56,11 +56,11 @@ func TestAfterBothActionToFunc(t *testing.T) {

countingTask1 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P1", 20*time.Millisecond))
countingTask2 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P2", 20*time.Millisecond))
t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 *int) error {
t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 int) error {
t := ctx.Value(testContextKey).(*testing.T)
t.Logf("result1: %d", result1)
t.Logf("result2: %d", result2)
sum := *result1 + *result2
sum := result1 + result2
t.Logf("sum: %d", sum)
return nil
}))
Expand Down
15 changes: 8 additions & 7 deletions continue_with.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,24 @@ package asynctask
import "context"

// ContinueFunc is a function that can be connected to previous task with ContinueWith
type ContinueFunc[T any, S any] func(context.Context, *T) (*S, error)
type ContinueFunc[T any, S any] func(context.Context, T) (S, error)

func ContinueWith[T any, S any](ctx context.Context, tsk *Task[T], next ContinueFunc[T, S]) *Task[S] {
return Start(ctx, func(fCtx context.Context) (*S, error) {
return Start(ctx, func(fCtx context.Context) (S, error) {
result, err := tsk.Result(fCtx)
if err != nil {
return nil, err
return *new(S), err
}
return next(fCtx, result)
})
}

// ContinueActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface.
// Action is function that runs without return anything
// Func is function that runs and return something
func ContinueActionToFunc[T any](action func(context.Context, *T) error) func(context.Context, *T) (*interface{}, error) {
return func(ctx context.Context, t *T) (*interface{}, error) {
//
// Action is function that runs without return anything
// Func is function that runs and return something
func ContinueActionToFunc[T any](action func(context.Context, T) error) func(context.Context, T) (interface{}, error) {
return func(ctx context.Context, t T) (interface{}, error) {
return nil, action(ctx, t)
}
}
26 changes: 13 additions & 13 deletions continue_with_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
return func(ctx context.Context) (int, error) {
t := ctx.Value(testContextKey).(*testing.T)

result := countFrom
Expand All @@ -22,47 +22,47 @@ func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duratio
result++
case <-ctx.Done():
t.Log("work canceled")
return &result, nil
return result, nil
}
}
return &result, nil
return result, nil
}
}

func TestContinueWith(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getAdvancedCountingTask(0, 10, 20*time.Millisecond))
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
})
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
})

result, err := t2.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t2.State(), "Task should complete with no error")
assert.Equal(t, *result, 20)
assert.Equal(t, result, 20)

result, err = t3.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
assert.Equal(t, *result, 22)
assert.Equal(t, result, 22)
}

func TestContinueWithFailureCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getErrorTask("devide by 0", 10*time.Millisecond))
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
})
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
})

Expand Down
10 changes: 5 additions & 5 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ import (
)

func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc[string] {
return func(ctx context.Context) (*string, error) {
return func(ctx context.Context) (string, error) {
time.Sleep(sleepDuration)
panic("yo")
}
}

func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
return func(ctx context.Context) (int, error) {
time.Sleep(sleepDuration)
return nil, errors.New(errorString)
return 0, errors.New(errorString)
}
}

Expand All @@ -37,12 +37,12 @@ func TestTimeoutCase(t *testing.T) {
// I can continue wait with longer time
rawResult, err := tsk.WaitWithTimeout(ctx, 2*time.Second)
assert.NoError(t, err)
assert.Equal(t, 9, *rawResult)
assert.Equal(t, 9, rawResult)

// any following Wait should complete immediately
rawResult, err = tsk.WaitWithTimeout(ctx, 2*time.Nanosecond)
assert.NoError(t, err)
assert.Equal(t, 9, *rawResult)
assert.Equal(t, 9, rawResult)
}

func TestPanicCase(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/Azure/go-asynctask

go 1.19
go 1.20

require github.com/stretchr/testify v1.8.4

Expand Down
41 changes: 20 additions & 21 deletions task.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
)

// AsyncFunc is a function interface this asyncTask accepts.
type AsyncFunc[T any] func(context.Context) (*T, error)
type AsyncFunc[T any] func(context.Context) (T, error)

// ActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface.
// - Action is function that runs without return anything
// - Func is function that runs and return something
func ActionToFunc(action func(context.Context) error) func(context.Context) (*interface{}, error) {
return func(ctx context.Context) (*interface{}, error) {
func ActionToFunc(action func(context.Context) error) func(context.Context) (interface{}, error) {
return func(ctx context.Context) (interface{}, error) {
return nil, action(ctx)
}
}
Expand All @@ -24,25 +24,25 @@ func ActionToFunc(action func(context.Context) error) func(context.Context) (*in
// which you can use to wait, cancel, get the result.
type Task[T any] struct {
state State
result *T
result T
err error
cancelFunc context.CancelFunc
waitGroup *sync.WaitGroup
mutex *sync.Mutex
mutex *sync.RWMutex
}

// State return state of the task.
func (t *Task[T]) State() State {
t.mutex.Lock()
defer t.mutex.Unlock()
t.mutex.RLock()
defer t.mutex.RUnlock()
return t.state
}

// Cancel the task by cancel the context.
// !! this rely on the task function to check context cancellation and proper context handling.
func (t *Task[T]) Cancel() bool {
if !t.finished() {
t.finish(StateCanceled, nil, ErrCanceled)
t.finish(StateCanceled, *new(T), ErrCanceled)
return true
}

Expand Down Expand Up @@ -74,7 +74,7 @@ func (t *Task[T]) Wait(ctx context.Context) error {

// WaitWithTimeout block current thread/routine until task finished or failed, or exceed the duration specified.
// timeout only stop waiting, taks will remain running.
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*T, error) {
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (T, error) {
// return immediately if task already in terminal state.
if t.finished() {
return t.result, t.err
Expand All @@ -86,11 +86,10 @@ func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*
return t.Result(ctx)
}

func (t *Task[T]) Result(ctx context.Context) (*T, error) {
func (t *Task[T]) Result(ctx context.Context) (T, error) {
err := t.Wait(ctx)
if err != nil {
var result T
return &result, err
return *new(T), err
}

return t.result, t.err
Expand All @@ -102,11 +101,11 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
ctx, cancel := context.WithCancel(ctx)
wg := &sync.WaitGroup{}
wg.Add(1)
mutex := &sync.Mutex{}
mutex := &sync.RWMutex{}

record := &Task[T]{
state: StateRunning,
result: nil,
result: *new(T),
cancelFunc: cancel,
waitGroup: wg,
mutex: mutex,
Expand All @@ -118,24 +117,24 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
}

// NewCompletedTask returns a Completed task, with result=nil, error=nil
func NewCompletedTask[T any](value *T) *Task[T] {
func NewCompletedTask[T any](value T) *Task[T] {
return &Task[T]{
state: StateCompleted,
result: value,
err: nil,
// nil cancelFunc and waitGroup should be protected with IsTerminalState()
cancelFunc: nil,
waitGroup: nil,
mutex: &sync.Mutex{},
mutex: &sync.RWMutex{},
}
}

func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (*T, error)) {
func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (T, error)) {
defer record.waitGroup.Done()
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic cought: %v, stackTrace: %s, %w", r, debug.Stack(), ErrPanic)
record.finish(StateFailed, nil, err)
record.finish(StateFailed, *new(T), err)
}
}()

Expand All @@ -150,7 +149,7 @@ func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task fu
record.finish(StateFailed, result, err)
}

func (t *Task[T]) finish(state State, result *T, err error) {
func (t *Task[T]) finish(state State, result T, err error) {
// only update state and result if not yet canceled
t.mutex.Lock()
defer t.mutex.Unlock()
Expand All @@ -163,7 +162,7 @@ func (t *Task[T]) finish(state State, result *T, err error) {
}

func (t *Task[T]) finished() bool {
t.mutex.Lock()
defer t.mutex.Unlock()
t.mutex.RLock()
defer t.mutex.RUnlock()
return t.state.IsTerminalState()
}
Loading

0 comments on commit 442a02a

Please sign in to comment.