use generics for asyncTasks, no type assertion required (#14)

* tweaks

* add test for generic version

* continueFunc in generic

* somewhat working

* update to go 1.18
This commit is contained in:
Haitao Chen 2022-05-31 13:49:21 -07:00 коммит произвёл GitHub
Родитель 55c33d025d
Коммит 22f82519b0
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
11 изменённых файлов: 164 добавлений и 232 удалений

12
.github/workflows/go.yml поставляемый
Просмотреть файл

@ -13,23 +13,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.13
- name: Set up Go 1.18
uses: actions/setup-go@v1
with:
go-version: 1.13
go-version: 1.18
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: Get dependencies
run: |
go get -v -t -d ./...
if [ -f Gopkg.toml ]; then
curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
dep ensure
fi
- name: Build
run: go build -v .

Просмотреть файл

@ -3,13 +3,11 @@ package asynctask
import "context"
// ContinueFunc is a function that can be connected to previous task with ContinueWith
type ContinueFunc func(context.Context, interface{}) (interface{}, error)
type ContinueFunc[T any, S any] func(context.Context, *T) (*S, error)
// ContinueWith start the function when current task is done.
// result from previous task will be passed in, if no error.
func (tsk *TaskStatus) ContinueWith(ctx context.Context, next ContinueFunc) *TaskStatus {
return Start(ctx, func(fCtx context.Context) (interface{}, error) {
result, err := tsk.Wait(fCtx)
func ContinueWith[T any, S any](ctx context.Context, tsk *Task[T], next ContinueFunc[T, S]) *Task[S] {
return Start(ctx, func(fCtx context.Context) (*S, error) {
result, err := tsk.Result(fCtx)
if err != nil {
return nil, err
}

Просмотреть файл

@ -9,8 +9,8 @@ import (
"github.com/stretchr/testify/assert"
)
func getAdvancedCountingTask(countFrom, step int, sleepInterval time.Duration) asynctask.AsyncFunc {
return func(ctx context.Context) (interface{}, error) {
func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
t := ctx.Value(testContextKey).(*testing.T)
result := countFrom
@ -21,10 +21,10 @@ func getAdvancedCountingTask(countFrom, step int, sleepInterval time.Duration) a
result++
case <-ctx.Done():
t.Log("work canceled")
return result, nil
return &result, nil
}
}
return result, nil
return &result, nil
}
}
@ -32,45 +32,45 @@ func TestContinueWith(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getAdvancedCountingTask(0, 10, 20*time.Millisecond))
t2 := t1.ContinueWith(ctx, func(fCtx context.Context, input interface{}) (interface{}, error) {
fromPrevTsk := input.(int)
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
})
t3 := t1.ContinueWith(ctx, func(fCtx context.Context, input interface{}) (interface{}, error) {
fromPrevTsk := input.(int)
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
})
result, err := t2.Wait(ctx)
result, err := t2.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t2.State(), "Task should complete with no error")
assert.Equal(t, result, 20)
assert.Equal(t, *result, 20)
result, err = t3.Wait(ctx)
result, err = t3.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
assert.Equal(t, result, 22)
assert.Equal(t, *result, 22)
}
func TestContinueWithFailureCase(t *testing.T) {
t.Parallel()
ctx := newTestContext(t)
t1 := asynctask.Start(ctx, getErrorTask("devide by 0", 10*time.Millisecond))
t2 := t1.ContinueWith(ctx, func(fCtx context.Context, input interface{}) (interface{}, error) {
fromPrevTsk := input.(int)
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
})
t3 := t1.ContinueWith(ctx, func(fCtx context.Context, input interface{}) (interface{}, error) {
fromPrevTsk := input.(int)
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
fromPrevTsk := *input
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
})
_, err := t2.Wait(ctx)
_, err := t2.Result(ctx)
assert.Error(t, err)
assert.Equal(t, asynctask.StateFailed, t2.State(), "Task2 should fail since preceeding task failed")
assert.Equal(t, "devide by 0", err.Error())
_, err = t3.Wait(ctx)
_, err = t3.Result(ctx)
assert.Error(t, err)
assert.Equal(t, asynctask.StateFailed, t3.State(), "Task3 should fail since preceeding task failed")
assert.Equal(t, "devide by 0", err.Error())

Просмотреть файл

@ -10,27 +10,15 @@ import (
"github.com/stretchr/testify/assert"
)
type structError struct{}
func (pe structError) Error() string {
return "Error from struct type"
}
type pointerError struct{}
func (pe *pointerError) Error() string {
return "Error from pointer type"
}
func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc {
return func(ctx context.Context) (interface{}, error) {
func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc[string] {
return func(ctx context.Context) (*string, error) {
time.Sleep(sleepDuration)
panic("yo")
}
}
func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc {
return func(ctx context.Context) (interface{}, error) {
func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
time.Sleep(sleepDuration)
return nil, errors.New(errorString)
}
@ -49,12 +37,12 @@ func TestTimeoutCase(t *testing.T) {
// I can continue wait with longer time
rawResult, err := tsk.WaitWithTimeout(ctx, 2*time.Second)
assert.NoError(t, err)
assert.Equal(t, 9, rawResult)
assert.Equal(t, 9, *rawResult)
// any following Wait should complete immediately
rawResult, err = tsk.WaitWithTimeout(ctx, 2*time.Nanosecond)
assert.NoError(t, err)
assert.Equal(t, 9, rawResult)
assert.Equal(t, 9, *rawResult)
}
func TestPanicCase(t *testing.T) {
@ -79,49 +67,3 @@ func TestErrorCase(t *testing.T) {
assert.False(t, errors.Is(err, context.DeadlineExceeded), "not expecting DeadlineExceeded")
assert.Equal(t, "dummy error", err.Error())
}
func TestPointerErrorCase(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()
// nil point of a type that implement error
var pe *pointerError = nil
// pass this nil pointer to error interface
var err error = pe
// now you get a non-nil error
assert.False(t, err == nil, "reason this test is needed")
tsk := asynctask.Start(ctx, func(ctx context.Context) (interface{}, error) {
time.Sleep(100 * time.Millisecond)
var pe *pointerError = nil
return "Done", pe
})
result, err := tsk.Wait(ctx)
assert.NoError(t, err)
assert.Equal(t, result, "Done")
}
func TestStructErrorCase(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()
// nil point of a type that implement error
var se structError
// pass this nil pointer to error interface
var err error = se
// now you get a non-nil error
assert.False(t, err == nil, "reason this test is needed")
tsk := asynctask.Start(ctx, func(ctx context.Context) (interface{}, error) {
time.Sleep(100 * time.Millisecond)
var se structError
return "Done", se
})
result, err := tsk.Wait(ctx)
assert.NoError(t, err)
assert.Equal(t, result, "Done")
}

10
go.mod
Просмотреть файл

@ -1,5 +1,11 @@
module github.com/Azure/go-asynctask
go 1.13
go 1.18
require github.com/stretchr/testify v1.7.0
require github.com/stretchr/testify v1.7.1
require (
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

5
go.sum
Просмотреть файл

@ -2,10 +2,9 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=

Просмотреть файл

@ -2,75 +2,48 @@ package asynctask
import (
"context"
"errors"
"fmt"
"reflect"
"runtime/debug"
"sync"
"time"
)
// State of a task.
type State string
// StateRunning indicate task is still running.
const StateRunning State = "Running"
// StateCompleted indicate task is finished.
const StateCompleted State = "Completed"
// StateFailed indicate task failed.
const StateFailed State = "Failed"
// StateCanceled indicate task got canceled.
const StateCanceled State = "Canceled"
// IsTerminalState tells whether the task finished
func (s State) IsTerminalState() bool {
return s != StateRunning
}
// AsyncFunc is a function interface this asyncTask accepts.
type AsyncFunc func(context.Context) (interface{}, error)
type AsyncFunc[T any] func(context.Context) (*T, error)
// ErrPanic is returned if panic cought in the task
var ErrPanic = errors.New("panic")
// ErrCanceled is returned if a cancel is triggered
var ErrCanceled = errors.New("canceled")
// TaskStatus is a handle to the running function.
// Task is a handle to the running function.
// which you can use to wait, cancel, get the result.
type TaskStatus struct {
type Task[T any] struct {
state State
result interface{}
result *T
err error
cancelFunc context.CancelFunc
waitGroup *sync.WaitGroup
}
// State return state of the task.
func (t *TaskStatus) State() State {
func (t *Task[T]) State() State {
return t.state
}
// Cancel abort the task execution
// !! only if the function provided handles context cancel.
func (t *TaskStatus) Cancel() {
func (t *Task[T]) Cancel() {
if !t.state.IsTerminalState() {
t.cancelFunc()
t.finish(StateCanceled, nil, ErrCanceled)
var result T
t.finish(StateCanceled, &result, ErrCanceled)
}
}
// Wait block current thread/routine until task finished or failed.
// context passed in can terminate the wait, through context cancellation
// but won't terminate the task (unless it's same context)
func (t *TaskStatus) Wait(ctx context.Context) (interface{}, error) {
func (t *Task[T]) Wait(ctx context.Context) error {
// return immediately if task already in terminal state.
if t.state.IsTerminalState() {
return t.result, t.err
return t.err
}
ch := make(chan interface{})
@ -81,15 +54,15 @@ func (t *TaskStatus) Wait(ctx context.Context) (interface{}, error) {
select {
case <-ch:
return t.result, t.err
return t.err
case <-ctx.Done():
return nil, ctx.Err()
return ctx.Err()
}
}
// WaitWithTimeout block current thread/routine until task finished or failed, or exceed the duration specified.
// timeout only stop waiting, taks will remain running.
func (t *TaskStatus) WaitWithTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*T, error) {
// return immediately if task already in terminal state.
if t.state.IsTerminalState() {
return t.result, t.err
@ -98,14 +71,44 @@ func (t *TaskStatus) WaitWithTimeout(ctx context.Context, timeout time.Duration)
ctx, cancelFunc := context.WithTimeout(ctx, timeout)
defer cancelFunc()
return t.Wait(ctx)
return t.Result(ctx)
}
func (t *Task[T]) Result(ctx context.Context) (*T, error) {
err := t.Wait(ctx)
if err != nil {
var result T
return &result, err
}
return t.result, t.err
}
// Start run a async function and returns you a handle which you can Wait or Cancel.
// context passed in may impact task lifetime (from context cancellation)
func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
ctx, cancel := context.WithCancel(ctx)
wg := &sync.WaitGroup{}
wg.Add(1)
var result T
record := &Task[T]{
state: StateRunning,
result: &result,
cancelFunc: cancel,
waitGroup: wg,
}
go runAndTrackGenericTask(ctx, record, task)
return record
}
// NewCompletedTask returns a Completed task, with result=nil, error=nil
func NewCompletedTask() *TaskStatus {
return &TaskStatus{
func NewCompletedTask[T any](value *T) *Task[T] {
return &Task[T]{
state: StateCompleted,
result: nil,
result: value,
err: nil,
// nil cancelFunc and waitGroup should be protected with IsTerminalState()
cancelFunc: nil,
@ -113,42 +116,7 @@ func NewCompletedTask() *TaskStatus {
}
}
// Start run a async function and returns you a handle which you can Wait or Cancel.
// context passed in may impact task lifetime (from context cancellation)
func Start(ctx context.Context, task AsyncFunc) *TaskStatus {
ctx, cancel := context.WithCancel(ctx)
wg := &sync.WaitGroup{}
wg.Add(1)
record := &TaskStatus{
state: StateRunning,
result: nil,
cancelFunc: cancel,
waitGroup: wg,
}
go runAndTrackTask(ctx, record, task)
return record
}
// isErrorReallyError do extra error check
// - Nil Pointer to a Type (that implement error)
// - Zero Value of a Type (that implement error)
func isErrorReallyError(err error) bool {
v := reflect.ValueOf(err)
if v.Type().Kind() == reflect.Ptr &&
v.IsNil() {
return false
}
if v.Type().Kind() == reflect.Struct &&
v.IsZero() {
return false
}
return true
}
func runAndTrackTask(ctx context.Context, record *TaskStatus, task func(ctx context.Context) (interface{}, error)) {
func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (*T, error)) {
defer record.waitGroup.Done()
defer func() {
if r := recover(); r != nil {
@ -159,11 +127,7 @@ func runAndTrackTask(ctx context.Context, record *TaskStatus, task func(ctx cont
result, err := task(ctx)
if err == nil ||
// incase some team use pointer typed error (implement Error() string on a pointer type)
// which can break err check (but nil point assigned to error result to non-nil error)
// check out TestPointerErrorCase in error_test.go
!isErrorReallyError(err) {
if err == nil {
record.finish(StateCompleted, result, nil)
return
}
@ -172,7 +136,7 @@ func runAndTrackTask(ctx context.Context, record *TaskStatus, task func(ctx cont
record.finish(StateFailed, result, err)
}
func (t *TaskStatus) finish(state State, result interface{}, err error) {
func (t *Task[T]) finish(state State, result *T, err error) {
// only update state and result if not yet canceled
if !t.state.IsTerminalState() {
t.state = state

Просмотреть файл

@ -9,9 +9,7 @@ import (
"github.com/stretchr/testify/assert"
)
type notMatter string
const testContextKey notMatter = "testing"
const testContextKey string = "testing"
func newTestContext(t *testing.T) context.Context {
return context.WithValue(context.TODO(), testContextKey, t)
@ -21,8 +19,8 @@ func newTestContextWithTimeout(t *testing.T, timeout time.Duration) (context.Con
return context.WithTimeout(context.WithValue(context.TODO(), testContextKey, t), timeout)
}
func getCountingTask(countTo int, sleepInterval time.Duration) asynctask.AsyncFunc {
return func(ctx context.Context) (interface{}, error) {
func getCountingTask(countTo int, sleepInterval time.Duration) asynctask.AsyncFunc[int] {
return func(ctx context.Context) (*int, error) {
t := ctx.Value(testContextKey).(*testing.T)
result := 0
@ -33,14 +31,14 @@ func getCountingTask(countTo int, sleepInterval time.Duration) asynctask.AsyncFu
result = i
case <-ctx.Done():
t.Log("work canceled")
return result, nil
return &result, nil
}
}
return result, nil
return &result, nil
}
}
func TestEasyCase(t *testing.T) {
func TestEasyGenericCase(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()
@ -48,28 +46,26 @@ func TestEasyCase(t *testing.T) {
t1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
assert.Equal(t, asynctask.StateRunning, t1.State(), "Task should queued to Running")
rawResult, err := t1.Wait(ctx)
rawResult, err := t1.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now")
assert.NotNil(t, rawResult)
result := rawResult.(int)
assert.Equal(t, result, 9)
assert.Equal(t, *rawResult, 9)
// wait Again,
start := time.Now()
rawResult, err = t1.Wait(ctx)
rawResult, err = t1.Result(ctx)
elapsed := time.Since(start)
// nothing should change
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now")
assert.NotNil(t, rawResult)
result = rawResult.(int)
assert.Equal(t, result, 9)
assert.Equal(t, *rawResult, 9)
assert.True(t, elapsed.Microseconds() < 3, "Second wait should return immediately")
}
func TestCancelFunc(t *testing.T) {
func TestCancelFuncOnGeneric(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()
@ -80,18 +76,16 @@ func TestCancelFunc(t *testing.T) {
time.Sleep(time.Second * 1)
t1.Cancel()
rawResult, err := t1.Wait(ctx)
_, err := t1.Result(ctx)
assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error")
assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state")
assert.Nil(t, rawResult)
// I can cancel again, and nothing changes
time.Sleep(time.Second * 1)
t1.Cancel()
rawResult, err = t1.Wait(ctx)
_, err = t1.Result(ctx)
assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error")
assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state")
assert.Nil(t, rawResult)
// cancel a task shouldn't cancel it's parent context.
select {
@ -102,7 +96,7 @@ func TestCancelFunc(t *testing.T) {
}
}
func TestConsistentResultAfterCancel(t *testing.T) {
func TestConsistentResultAfterCancelGenericTask(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()
@ -119,24 +113,25 @@ func TestConsistentResultAfterCancel(t *testing.T) {
assert.True(t, duration < 1*time.Millisecond, "cancel shouldn't take that long")
// wait til routine finish
rawResult, err := t2.Wait(ctx)
rawResult, err := t2.Result(ctx)
assert.NoError(t, err)
assert.Equal(t, asynctask.StateCompleted, t2.State(), "t2 should complete")
assert.Equal(t, rawResult, 9)
assert.Equal(t, *rawResult, 9)
// t1 should remain canceled and
rawResult, err = t1.Wait(ctx)
rawResult, err = t1.Result(ctx)
assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error")
assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state")
assert.Nil(t, rawResult)
assert.Equal(t, *rawResult, 0) // default value for int
}
func TestCompletedTask(t *testing.T) {
func TestCompletedGenericTask(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()
tsk := asynctask.NewCompletedTask()
result := "something"
tsk := asynctask.NewCompletedTask(&result)
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should in CompletedState")
// nothing should happen
@ -144,19 +139,19 @@ func TestCompletedTask(t *testing.T) {
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")
// you get nil result and nil error
result, err := tsk.Wait(ctx)
resultGet, err := tsk.Result(ctx)
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")
assert.NoError(t, err)
assert.Nil(t, result)
assert.Equal(t, *resultGet, result)
}
func TestCrazyCase(t *testing.T) {
func TestCrazyCaseGeneric(t *testing.T) {
t.Parallel()
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
defer cancelFunc()
numOfTasks := 8000 // if you have --race switch on: limit on 8128 simultaneously alive goroutines is exceeded, dying
tasks := map[int]*asynctask.TaskStatus{}
tasks := map[int]*asynctask.Task[int]{}
for i := 0; i < numOfTasks; i++ {
tasks[i] = asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
}
@ -167,17 +162,14 @@ func TestCrazyCase(t *testing.T) {
}
for i := 0; i < numOfTasks; i += 1 {
rawResult, err := tasks[i].Wait(ctx)
rawResult, err := tasks[i].Result(ctx)
if i%2 == 0 {
assert.Equal(t, asynctask.ErrCanceled, err, "should be canceled")
assert.Nil(t, rawResult)
assert.Equal(t, *rawResult, 0)
} else {
assert.NoError(t, err)
assert.NotNil(t, rawResult)
result := rawResult.(int)
assert.Equal(t, result, 9)
assert.Equal(t, *rawResult, 9)
}
}
}

31
types.go Normal file
Просмотреть файл

@ -0,0 +1,31 @@
package asynctask
import (
"errors"
)
// State of a task.
type State string
// StateRunning indicate task is still running.
const StateRunning State = "Running"
// StateCompleted indicate task is finished.
const StateCompleted State = "Completed"
// StateFailed indicate task failed.
const StateFailed State = "Failed"
// StateCanceled indicate task got canceled.
const StateCanceled State = "Canceled"
// IsTerminalState tells whether the task finished
func (s State) IsTerminalState() bool {
return s != StateRunning
}
// ErrPanic is returned if panic cought in the task
var ErrPanic = errors.New("panic")
// ErrCanceled is returned if a cancel is triggered
var ErrCanceled = errors.New("canceled")

Просмотреть файл

@ -6,6 +6,10 @@ import (
"sync"
)
type Waitable interface {
Wait(context.Context) error
}
// WaitAllOptions defines options for WaitAll function
type WaitAllOptions struct {
// FailFast set to true will indicate WaitAll to return on first error it sees.
@ -14,7 +18,7 @@ type WaitAllOptions struct {
// WaitAll block current thread til all task finished.
// first error from any tasks passed in will be returned.
func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...*TaskStatus) error {
func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...Waitable) error {
tasksCount := len(tasks)
mutex := sync.Mutex{}
@ -65,8 +69,8 @@ func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...*TaskStatus)
return nil
}
func waitOne(ctx context.Context, tsk *TaskStatus, errorCh chan<- error, errorChClosed *bool, mutex *sync.Mutex) {
_, err := tsk.Wait(ctx)
func waitOne(ctx context.Context, tsk Waitable, errorCh chan<- error, errorChClosed *bool, mutex *sync.Mutex) {
err := tsk.Wait(ctx)
// why mutex?
// if all tasks start using same context (unittest is good example)

Просмотреть файл

@ -18,7 +18,8 @@ func TestWaitAll(t *testing.T) {
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond))
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond))
completedTsk := asynctask.NewCompletedTask()
result := "something"
completedTsk := asynctask.NewCompletedTask(&result)
start := time.Now()
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk)
@ -36,7 +37,8 @@ func TestWaitAllFailFastCase(t *testing.T) {
countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond))
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond))
completedTsk := asynctask.NewCompletedTask()
result := "something"
completedTsk := asynctask.NewCompletedTask(&result)
start := time.Now()
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk, errorTsk, panicTsk, completedTsk)
@ -61,7 +63,8 @@ func TestWaitAllErrorCase(t *testing.T) {
countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond))
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond))
completedTsk := asynctask.NewCompletedTask()
result := "something"
completedTsk := asynctask.NewCompletedTask(&result)
start := time.Now()
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: false}, countingTsk, errorTsk, panicTsk, completedTsk)
@ -86,7 +89,8 @@ func TestWaitAllCanceled(t *testing.T) {
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond))
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond))
completedTsk := asynctask.NewCompletedTask()
result := "something"
completedTsk := asynctask.NewCompletedTask(&result)
waitCtx, cancelFunc1 := context.WithTimeout(ctx, 5*time.Millisecond)
defer cancelFunc1()