Skip to content

Commit

Permalink
Merge pull request #3 from Azure/haitao/handlesPointerError
Browse files Browse the repository at this point in the history
handles pointer error types
  • Loading branch information
haitch committed Apr 13, 2020
2 parents 0d4baf0 + e6d0031 commit d54bfec
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 16 deletions.
55 changes: 41 additions & 14 deletions async_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"reflect"
"runtime/debug"
"sync"
"time"
Expand Down Expand Up @@ -60,31 +61,38 @@ func (t *TaskStatus) State() State {
// Cancel abort the task execution
// !! only if the function provided handles context cancel.
func (t *TaskStatus) Cancel() {
t.cancelFunc()
if !t.state.IsTerminalState() {
t.cancelFunc()

t.finish(StateCanceled, nil, ErrCanceled)
t.finish(StateCanceled, nil, ErrCanceled)
}
}

// Wait block current thread/routine until task finished or failed.
func (t *TaskStatus) Wait() (interface{}, error) {
// skip the wait if task got canceled.
if !t.state.IsTerminalState() {
t.waitGroup.Wait()

// we create new context when starting task, now release it.
t.cancelFunc()
// return immediately if task already in terminal state.
if t.state.IsTerminalState() {
return t.result, t.err
}

// we create new context when starting task, now release it.
defer t.cancelFunc()

t.waitGroup.Wait()

return t.result, t.err
}

// WaitWithTimeout block current thread/routine until task finished or failed, or exceed the duration specified.
func (t *TaskStatus) WaitWithTimeout(timeout time.Duration) (interface{}, error) {
defer t.cancelFunc()
// return immediately if task already in terminal state.
if t.state.IsTerminalState() {
return t.result, t.err
}

ch := make(chan interface{})
go func() {
t.waitGroup.Wait()
t.Wait()
close(ch)
}()

Expand All @@ -97,6 +105,18 @@ func (t *TaskStatus) WaitWithTimeout(timeout time.Duration) (interface{}, error)
}
}

// NewCompletedTask returns a Completed task, with result=nil, error=nil
func NewCompletedTask() *TaskStatus {
return &TaskStatus{
state: StateCompleted,
result: nil,
err: nil,
// nil cancelFunc and waitGroup should be protected with IsTerminalState()
cancelFunc: nil,
waitGroup: nil,
}
}

// Start run a async function and returns you a handle which you can Wait or Cancel.
func Start(ctx context.Context, task AsyncFunc) *TaskStatus {
ctx, cancel := context.WithCancel(ctx)
Expand Down Expand Up @@ -125,11 +145,18 @@ func runAndTrackTask(record *TaskStatus, task func(ctx context.Context) (interfa
}()

result, err := task(record)
if err != nil {
record.finish(StateFailed, result, err)
} else {
record.finish(StateCompleted, result, err)

if err == nil ||
// incase some team use pointer typed error (implement Error() string on a pointer type)
// which can break err check (but nil point assigned to error result to non-nil error)
// check out TestPointerErrorCase in error_test.go
reflect.ValueOf(err).IsNil() {
record.finish(StateCompleted, result, nil)
return
}

// err not nil, fail the task
record.finish(StateFailed, result, err)
}

func (t *TaskStatus) finish(state State, result interface{}, err error) {
Expand Down
38 changes: 36 additions & 2 deletions async_task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,23 @@ func TestEasyCase(t *testing.T) {

rawResult, err := t1.Wait()
assert.NoError(t, err)

assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now")
assert.NotNil(t, rawResult)
result := rawResult.(int)
assert.Equal(t, result, 9)

//assert.Fail(t, "just want to see if trace is working")
// wait Again,
start := time.Now()
rawResult, err = t1.Wait()
elapsed := time.Since(start)
// nothing should change
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now")
assert.NotNil(t, rawResult)
result = rawResult.(int)
assert.Equal(t, result, 9)

assert.True(t, elapsed.Microseconds() < 2, "Second wait should take more than 2 millisecond")
}

func TestCancelFunc(t *testing.T) {
Expand All @@ -66,7 +76,14 @@ func TestCancelFunc(t *testing.T) {

rawResult, err := t1.Wait()
assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error")
assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state")
assert.Nil(t, rawResult)

// I can cancel again, and nothing changes
time.Sleep(time.Second * 1)
t1.Cancel()
rawResult, err = t1.Wait()
assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error")
assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state")
assert.Nil(t, rawResult)

Expand Down Expand Up @@ -135,6 +152,23 @@ func TestConsistentResultAfterTimeout(t *testing.T) {
assert.Nil(t, rawResult, "didn't expect resule on canceled task")
}

func TestCompletedTask(t *testing.T) {
t.Parallel()

tsk := asynctask.NewCompletedTask()
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should in CompletedState")

// nothing should happen
tsk.Cancel()
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")

// you get nil result and nil error
result, err := tsk.Wait()
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")
assert.NoError(t, err)
assert.Nil(t, result)
}

func TestCrazyCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
Expand Down
28 changes: 28 additions & 0 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import (
"github.com/stretchr/testify/assert"
)

type pointerError struct{}

func (pe *pointerError) Error() string {
return "Error from pointer type"
}

func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc {
return func(ctx context.Context) (interface{}, error) {
time.Sleep(sleepDuration)
Expand Down Expand Up @@ -49,3 +55,25 @@ func TestErrorCase(t *testing.T) {
assert.False(t, errors.Is(err, asynctask.ErrPanic), "not expecting ErrPanic")
assert.False(t, errors.Is(err, asynctask.ErrTimeout), "not expecting ErrTimeout")
}

func TestPointerErrorCase(t *testing.T) {
t.Parallel()

// nil point of a type that implement error
var pe *pointerError = nil
// pass this nil pointer to error interface
var err error = pe
// now you get a non-nil error
assert.False(t, err == nil, "reason this test is needed")

ctx := newTestContext(t)
tsk := asynctask.Start(ctx, func(ctx context.Context) (interface{}, error) {
time.Sleep(100 * time.Millisecond)
var pe *pointerError = nil
return "Done", pe
})

result, err := tsk.Wait()
assert.NoError(t, err)
assert.Equal(t, result, "Done")
}

0 comments on commit d54bfec

Please sign in to comment.