remove restriction on return pointer value of TypeParameter (#33)

* remove pointer bind

* Change sync.Mutex to sync.RWMutex in Task struct

* Update go.mod to use go 1.20

* tweaks
This commit is contained in:
Haitao Chen 2023-11-13 14:08:57 -08:00 коммит произвёл GitHub
Родитель 4c78f32854
Коммит 442a02a75b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 83 добавлений и 73 удалений

Просмотреть файл

@ -7,6 +7,14 @@
Simple mimik of async/await for those come from C# world, so you don't need to dealing with waitGroup/channel in golang.
also the result is strongTyped with go generics, no type assertion is needed.
few chaining method provided:
- ContinueWith: send task1's output to task2 as input, return reference to task2.
- AfterBoth : send output of taskA, taskB to taskC as input, return reference to taskC.
- WaitAll: all of the task have to finish to end the wait (with an option to fail early if any task failed)
- WaitAny: any of the task finish would end the wait
```golang
// start task
task := asynctask.Start(ctx, countingTask)

Просмотреть файл

@ -3,20 +3,21 @@ package asynctask
import "context"
// AfterBothFunc is a function that has 2 input.
type AfterBothFunc[T, S, R any] func(context.Context, *T, *S) (*R, error)
type AfterBothFunc[T, S, R any] func(context.Context, T, S) (R, error)
// AfterBoth runs the function after both 2 input task finished, and will be fed with result from 2 input task.
// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running.
//
// if one of the input task failed, the AfterBoth task will be failed and returned, even other one are still running.
func AfterBoth[T, S, R any](ctx context.Context, tskT *Task[T], tskS *Task[S], next AfterBothFunc[T, S, R]) *Task[R] {
return Start(ctx, func(fCtx context.Context) (*R, error) {
return Start(ctx, func(fCtx context.Context) (R, error) {
t, err := tskT.Result(fCtx)
if err != nil {
return nil, err
return *new(R), err
}
s, err := tskS.Result(fCtx)
if err != nil {
return nil, err
return *new(R), err
}
return next(fCtx, t, s)
@ -24,10 +25,11 @@ func AfterBoth[T, S, R any](ctx context.Context, tskT *Task[T], tskS *Task[S], n
}
// AfterBothActionToFunc convert a Action to Func (C# term), to satisfy the AfterBothFunc interface.
// Action is function that runs without return anything
// Func is function that runs and return something
func AfterBothActionToFunc[T, S any](action func(context.Context, *T, *S) error) func(context.Context, *T, *S) (*interface{}, error) {
return func(ctx context.Context, t *T, s *S) (*interface{}, error) {
//
// Action is function that runs without return anything
// Func is function that runs and return something
func AfterBothActionToFunc[T, S any](action func(context.Context, T, S) error) func(context.Context, T, S) (interface{}, error) {
return func(ctx context.Context, t T, s S) (interface{}, error) {
return nil, action(ctx, t, s)
}
}

Просмотреть файл

@ -9,13 +9,13 @@ import (
"github.com/stretchr/testify/assert"
)
func summarize2CountingTask(ctx context.Context, result1, result2 *int) (*int, error) {
func summarize2CountingTask(ctx context.Context, result1, result2 int) (int, error) {
t := ctx.Value(testContextKey).(*testing.T)
t.Logf("result1: %d", result1)
t.Logf("result2: %d", result2)
sum := *result1 + *result2
sum := result1 + result2
t.Logf("sum: %d", sum)
return &sum, nil
return sum, nil
}
func TestAfterBoth(t *testing.T) {
@ -28,7 +28,7 @@ func TestAfterBoth(t *testing.T) {
sum, err := t3.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
assert.Equal(t, *sum, 18, "Sum should be 18")
assert.Equal(t, sum, 18, "Sum should be 18")
}
func TestAfterBothFailureCase(t *testing.T) {
@ -56,11 +56,11 @@ func TestAfterBothActionToFunc(t *testing.T) {
countingTask1 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P1", 20*time.Millisecond))
countingTask2 := asynctask.Start(ctx, getCountingTask(10, "afterboth.P2", 20*time.Millisecond))
t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 *int) error {
t2 := asynctask.AfterBoth(ctx, countingTask1, countingTask2, asynctask.AfterBothActionToFunc(func(ctx context.Context, result1, result2 int) error {
t := ctx.Value(testContextKey).(*testing.T)
t.Logf("result1: %d", result1)
t.Logf("result2: %d", result2)
sum := *result1 + *result2
sum := result1 + result2
t.Logf("sum: %d", sum)
return nil
}))

Просмотреть файл

@ -3,23 +3,24 @@ package asynctask
import "context"
// ContinueFunc is a function that can be connected to previous task with ContinueWith
type ContinueFunc[T any, S any] func(context.Context, *T) (*S, error)
type ContinueFunc[T any, S any] func(context.Context, T) (S, error)
func ContinueWith[T any, S any](ctx context.Context, tsk *Task[T], next ContinueFunc[T, S]) *Task[S] {
return Start(ctx, func(fCtx context.Context) (*S, error) {
return Start(ctx, func(fCtx context.Context) (S, error) {
result, err := tsk.Result(fCtx)
if err != nil {
return nil, err
return *new(S), err
}
return next(fCtx, result)
})
}
// ContinueActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface.
// Action is function that runs without return anything
// Func is function that runs and return something
func ContinueActionToFunc[T any](action func(context.Context, *T) error) func(context.Context, *T) (*interface{}, error) {
return func(ctx context.Context, t *T) (*interface{}, error) {
//
// Action is function that runs without return anything
// Func is function that runs and return something
func ContinueActionToFunc[T any](action func(context.Context, T) error) func(context.Context, T) (interface{}, error) {
return func(ctx context.Context, t T) (interface{}, error) {
return nil, action(ctx, t)
}
}

Просмотреть файл

@ -11,7 +11,7 @@ import (
)
func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
return func(ctx context.Context) (int, error) {
t := ctx.Value(testContextKey).(*testing.T)
result := countFrom
@ -22,10 +22,10 @@ func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duratio
result++
case <-ctx.Done():
t.Log("work canceled")
return &result, nil
return result, nil
}
}
return &result, nil
return result, nil
}
}
@ -33,36 +33,36 @@ func TestContinueWith(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getAdvancedCountingTask(0, 10, 20*time.Millisecond))
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
})
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
})
result, err := t2.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t2.State(), "Task should complete with no error")
assert.Equal(t, *result, 20)
assert.Equal(t, result, 20)
result, err = t3.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
assert.Equal(t, *result, 22)
assert.Equal(t, result, 22)
}
func TestContinueWithFailureCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getErrorTask("devide by 0", 10*time.Millisecond))
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
})
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input int) (int, error) {
fromPrevTsk := input
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
})

Просмотреть файл

@ -11,16 +11,16 @@ import (
)
func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc[string] {
return func(ctx context.Context) (*string, error) {
return func(ctx context.Context) (string, error) {
time.Sleep(sleepDuration)
panic("yo")
}
}
func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
return func(ctx context.Context) (int, error) {
time.Sleep(sleepDuration)
return nil, errors.New(errorString)
return 0, errors.New(errorString)
}
}
@ -37,12 +37,12 @@ func TestTimeoutCase(t *testing.T) {
// I can continue wait with longer time
rawResult, err := tsk.WaitWithTimeout(ctx, 2*time.Second)
assert.NoError(t, err)
assert.Equal(t, 9, *rawResult)
assert.Equal(t, 9, rawResult)
// any following Wait should complete immediately
rawResult, err = tsk.WaitWithTimeout(ctx, 2*time.Nanosecond)
assert.NoError(t, err)
assert.Equal(t, 9, *rawResult)
assert.Equal(t, 9, rawResult)
}
func TestPanicCase(t *testing.T) {

2
go.mod
Просмотреть файл

@ -1,6 +1,6 @@
module github.com/Azure/go-asynctask
go 1.19
go 1.20
require github.com/stretchr/testify v1.8.4

41
task.go
Просмотреть файл

@ -9,13 +9,13 @@ import (
)
// AsyncFunc is a function interface this asyncTask accepts.
type AsyncFunc[T any] func(context.Context) (*T, error)
type AsyncFunc[T any] func(context.Context) (T, error)
// ActionToFunc convert a Action to Func (C# term), to satisfy the AsyncFunc interface.
// - Action is function that runs without return anything
// - Func is function that runs and return something
func ActionToFunc(action func(context.Context) error) func(context.Context) (*interface{}, error) {
return func(ctx context.Context) (*interface{}, error) {
func ActionToFunc(action func(context.Context) error) func(context.Context) (interface{}, error) {
return func(ctx context.Context) (interface{}, error) {
return nil, action(ctx)
}
}
@ -24,17 +24,17 @@ func ActionToFunc(action func(context.Context) error) func(context.Context) (*in
// which you can use to wait, cancel, get the result.
type Task[T any] struct {
state State
result *T
result T
err error
cancelFunc context.CancelFunc
waitGroup *sync.WaitGroup
mutex *sync.Mutex
mutex *sync.RWMutex
}
// State return state of the task.
func (t *Task[T]) State() State {
t.mutex.Lock()
defer t.mutex.Unlock()
t.mutex.RLock()
defer t.mutex.RUnlock()
return t.state
}
@ -42,7 +42,7 @@ func (t *Task[T]) State() State {
// !! this rely on the task function to check context cancellation and proper context handling.
func (t *Task[T]) Cancel() bool {
if !t.finished() {
t.finish(StateCanceled, nil, ErrCanceled)
t.finish(StateCanceled, *new(T), ErrCanceled)
return true
}
@ -74,7 +74,7 @@ func (t *Task[T]) Wait(ctx context.Context) error {
// WaitWithTimeout block current thread/routine until task finished or failed, or exceed the duration specified.
// timeout only stop waiting, taks will remain running.
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*T, error) {
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (T, error) {
// return immediately if task already in terminal state.
if t.finished() {
return t.result, t.err
@ -86,11 +86,10 @@ func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*
return t.Result(ctx)
}
func (t *Task[T]) Result(ctx context.Context) (*T, error) {
func (t *Task[T]) Result(ctx context.Context) (T, error) {
err := t.Wait(ctx)
if err != nil {
var result T
return &result, err
return *new(T), err
}
return t.result, t.err
@ -102,11 +101,11 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
ctx, cancel := context.WithCancel(ctx)
wg := &sync.WaitGroup{}
wg.Add(1)
mutex := &sync.Mutex{}
mutex := &sync.RWMutex{}
record := &Task[T]{
state: StateRunning,
result: nil,
result: *new(T),
cancelFunc: cancel,
waitGroup: wg,
mutex: mutex,
@ -118,7 +117,7 @@ func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
}
// NewCompletedTask returns a Completed task, with result=nil, error=nil
func NewCompletedTask[T any](value *T) *Task[T] {
func NewCompletedTask[T any](value T) *Task[T] {
return &Task[T]{
state: StateCompleted,
result: value,
@ -126,16 +125,16 @@ func NewCompletedTask[T any](value *T) *Task[T] {
// nil cancelFunc and waitGroup should be protected with IsTerminalState()
cancelFunc: nil,
waitGroup: nil,
mutex: &sync.Mutex{},
mutex: &sync.RWMutex{},
}
}
func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (*T, error)) {
func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (T, error)) {
defer record.waitGroup.Done()
defer func() {
if r := recover(); r != nil {
err := fmt.Errorf("panic cought: %v, stackTrace: %s, %w", r, debug.Stack(), ErrPanic)
record.finish(StateFailed, nil, err)
record.finish(StateFailed, *new(T), err)
}
}()
@ -150,7 +149,7 @@ func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task fu
record.finish(StateFailed, result, err)
}
func (t *Task[T]) finish(state State, result *T, err error) {
func (t *Task[T]) finish(state State, result T, err error) {
// only update state and result if not yet canceled
t.mutex.Lock()
defer t.mutex.Unlock()
@ -163,7 +162,7 @@ func (t *Task[T]) finish(state State, result *T, err error) {
}
func (t *Task[T]) finished() bool {
t.mutex.Lock()
defer t.mutex.Unlock()
t.mutex.RLock()
defer t.mutex.RUnlock()
return t.state.IsTerminalState()
}

Просмотреть файл

@ -22,7 +22,7 @@ func newTestContextWithTimeout(t *testing.T, timeout time.Duration) (context.Con
}
func getCountingTask(countTo int, taskId string, sleepInterval time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
return func(ctx context.Context) (int, error) {
t := ctx.Value(testContextKey).(*testing.T)
result := 0
@ -35,10 +35,10 @@ func getCountingTask(countTo int, taskId string, sleepInterval time.Duration) as
// testing.Logf would cause DataRace error when test is already finished: https://github.com/golang/go/issues/40343
// leave minor time buffer before exit test to finish this last logging at least.
t.Logf("[%s]: work canceled", taskId)
return &result, nil
return result, nil
}
}
return &result, nil
return result, nil
}
}
@ -54,7 +54,7 @@ func TestEasyGenericCase(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now")
assert.NotNil(t, rawResult)
assert.Equal(t, *rawResult, 9)
assert.Equal(t, rawResult, 9)
// wait Again,
start := time.Now()
@ -64,7 +64,7 @@ func TestEasyGenericCase(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now")
assert.NotNil(t, rawResult)
assert.Equal(t, *rawResult, 9)
assert.Equal(t, rawResult, 9)
// Result should be returned immediately
assert.True(t, elapsed.Milliseconds() < 1, fmt.Sprintf("Second wait should have return immediately: %s", elapsed))
@ -121,13 +121,13 @@ func TestConsistentResultAfterCancelGenericTask(t *testing.T) {
rawResult, err := t2.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t2.State(), "t2 should complete")
assert.Equal(t, *rawResult, 9)
assert.Equal(t, rawResult, 9)
// t1 should remain canceled and
rawResult, err = t1.Result(ctx)
assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error")
assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state")
assert.Equal(t, *rawResult, 0) // default value for int
assert.Equal(t, rawResult, 0) // default value for int
}
func TestCompletedGenericTask(t *testing.T) {
@ -187,10 +187,10 @@ func TestCrazyCaseGeneric(t *testing.T) {
if i%2 == 0 {
assert.Equal(t, asynctask.ErrCanceled, err, fmt.Sprintf("task %s should be canceled, but it finished with %+v", fmt.Sprintf("CrazyTask%d", i), rawResult))
assert.Equal(t, *rawResult, 0)
assert.Equal(t, rawResult, 0)
} else {
assert.NoError(t, err)
assert.Equal(t, *rawResult, 9)
assert.Equal(t, rawResult, 9)
}
}
}

Просмотреть файл

@ -186,7 +186,7 @@ func TestWaitAllWithNoTasks(t *testing.T) {
// getUncontrollableTask return a task that is not honor context, it only hornor the remoteControl context.
func getUncontrollableTask(rcCtx context.Context, t *testing.T) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
return func(ctx context.Context) (int, error) {
for {
select {
case <-time.After(1 * time.Millisecond):
@ -195,7 +195,7 @@ func getUncontrollableTask(rcCtx context.Context, t *testing.T) asynctask.AsyncF
}
case <-rcCtx.Done():
t.Logf("[UncontrollableTask]: cancelled by remote control")
return nil, rcCtx.Err()
return 0, rcCtx.Err()
}
}
}