зеркало из https://github.com/Azure/go-asynctask.git
use generics for asyncTasks, no type assertion required (#14)
* tweaks * add test for generic version * continueFunc in generic * somewhat working * update to go 1.18
This commit is contained in:
Родитель
55c33d025d
Коммит
22f82519b0
|
@ -13,23 +13,15 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- name: Set up Go 1.13
|
||||
- name: Set up Go 1.18
|
||||
uses: actions/setup-go@v1
|
||||
with:
|
||||
go-version: 1.13
|
||||
go-version: 1.18
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Get dependencies
|
||||
run: |
|
||||
go get -v -t -d ./...
|
||||
if [ -f Gopkg.toml ]; then
|
||||
curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
|
||||
dep ensure
|
||||
fi
|
||||
|
||||
- name: Build
|
||||
run: go build -v .
|
||||
|
||||
|
|
|
@ -3,13 +3,11 @@ package asynctask
|
|||
import "context"
|
||||
|
||||
// ContinueFunc is a function that can be connected to previous task with ContinueWith
|
||||
type ContinueFunc func(context.Context, interface{}) (interface{}, error)
|
||||
type ContinueFunc[T any, S any] func(context.Context, *T) (*S, error)
|
||||
|
||||
// ContinueWith start the function when current task is done.
|
||||
// result from previous task will be passed in, if no error.
|
||||
func (tsk *TaskStatus) ContinueWith(ctx context.Context, next ContinueFunc) *TaskStatus {
|
||||
return Start(ctx, func(fCtx context.Context) (interface{}, error) {
|
||||
result, err := tsk.Wait(fCtx)
|
||||
func ContinueWith[T any, S any](ctx context.Context, tsk *Task[T], next ContinueFunc[T, S]) *Task[S] {
|
||||
return Start(ctx, func(fCtx context.Context) (*S, error) {
|
||||
result, err := tsk.Result(fCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -9,8 +9,8 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func getAdvancedCountingTask(countFrom, step int, sleepInterval time.Duration) asynctask.AsyncFunc {
|
||||
return func(ctx context.Context) (interface{}, error) {
|
||||
func getAdvancedCountingTask(countFrom int, step int, sleepInterval time.Duration) asynctask.AsyncFunc[int] {
|
||||
return func(ctx context.Context) (*int, error) {
|
||||
t := ctx.Value(testContextKey).(*testing.T)
|
||||
|
||||
result := countFrom
|
||||
|
@ -21,10 +21,10 @@ func getAdvancedCountingTask(countFrom, step int, sleepInterval time.Duration) a
|
|||
result++
|
||||
case <-ctx.Done():
|
||||
t.Log("work canceled")
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,45 +32,45 @@ func TestContinueWith(t *testing.T) {
|
|||
t.Parallel()
|
||||
ctx := newTestContext(t)
|
||||
t1 := asynctask.Start(ctx, getAdvancedCountingTask(0, 10, 20*time.Millisecond))
|
||||
t2 := t1.ContinueWith(ctx, func(fCtx context.Context, input interface{}) (interface{}, error) {
|
||||
fromPrevTsk := input.(int)
|
||||
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
|
||||
fromPrevTsk := *input
|
||||
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
|
||||
})
|
||||
t3 := t1.ContinueWith(ctx, func(fCtx context.Context, input interface{}) (interface{}, error) {
|
||||
fromPrevTsk := input.(int)
|
||||
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
|
||||
fromPrevTsk := *input
|
||||
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
|
||||
})
|
||||
|
||||
result, err := t2.Wait(ctx)
|
||||
result, err := t2.Result(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, asynctask.StateCompleted, t2.State(), "Task should complete with no error")
|
||||
assert.Equal(t, result, 20)
|
||||
assert.Equal(t, *result, 20)
|
||||
|
||||
result, err = t3.Wait(ctx)
|
||||
result, err = t3.Result(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, asynctask.StateCompleted, t3.State(), "Task should complete with no error")
|
||||
assert.Equal(t, result, 22)
|
||||
assert.Equal(t, *result, 22)
|
||||
}
|
||||
|
||||
func TestContinueWithFailureCase(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := newTestContext(t)
|
||||
t1 := asynctask.Start(ctx, getErrorTask("devide by 0", 10*time.Millisecond))
|
||||
t2 := t1.ContinueWith(ctx, func(fCtx context.Context, input interface{}) (interface{}, error) {
|
||||
fromPrevTsk := input.(int)
|
||||
t2 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
|
||||
fromPrevTsk := *input
|
||||
return getAdvancedCountingTask(fromPrevTsk, 10, 20*time.Millisecond)(fCtx)
|
||||
})
|
||||
t3 := t1.ContinueWith(ctx, func(fCtx context.Context, input interface{}) (interface{}, error) {
|
||||
fromPrevTsk := input.(int)
|
||||
t3 := asynctask.ContinueWith(ctx, t1, func(fCtx context.Context, input *int) (*int, error) {
|
||||
fromPrevTsk := *input
|
||||
return getAdvancedCountingTask(fromPrevTsk, 12, 20*time.Millisecond)(fCtx)
|
||||
})
|
||||
|
||||
_, err := t2.Wait(ctx)
|
||||
_, err := t2.Result(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, asynctask.StateFailed, t2.State(), "Task2 should fail since preceeding task failed")
|
||||
assert.Equal(t, "devide by 0", err.Error())
|
||||
|
||||
_, err = t3.Wait(ctx)
|
||||
_, err = t3.Result(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, asynctask.StateFailed, t3.State(), "Task3 should fail since preceeding task failed")
|
||||
assert.Equal(t, "devide by 0", err.Error())
|
||||
|
|
|
@ -10,27 +10,15 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type structError struct{}
|
||||
|
||||
func (pe structError) Error() string {
|
||||
return "Error from struct type"
|
||||
}
|
||||
|
||||
type pointerError struct{}
|
||||
|
||||
func (pe *pointerError) Error() string {
|
||||
return "Error from pointer type"
|
||||
}
|
||||
|
||||
func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc {
|
||||
return func(ctx context.Context) (interface{}, error) {
|
||||
func getPanicTask(sleepDuration time.Duration) asynctask.AsyncFunc[string] {
|
||||
return func(ctx context.Context) (*string, error) {
|
||||
time.Sleep(sleepDuration)
|
||||
panic("yo")
|
||||
}
|
||||
}
|
||||
|
||||
func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc {
|
||||
return func(ctx context.Context) (interface{}, error) {
|
||||
func getErrorTask(errorString string, sleepDuration time.Duration) asynctask.AsyncFunc[int] {
|
||||
return func(ctx context.Context) (*int, error) {
|
||||
time.Sleep(sleepDuration)
|
||||
return nil, errors.New(errorString)
|
||||
}
|
||||
|
@ -49,12 +37,12 @@ func TestTimeoutCase(t *testing.T) {
|
|||
// I can continue wait with longer time
|
||||
rawResult, err := tsk.WaitWithTimeout(ctx, 2*time.Second)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 9, rawResult)
|
||||
assert.Equal(t, 9, *rawResult)
|
||||
|
||||
// any following Wait should complete immediately
|
||||
rawResult, err = tsk.WaitWithTimeout(ctx, 2*time.Nanosecond)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 9, rawResult)
|
||||
assert.Equal(t, 9, *rawResult)
|
||||
}
|
||||
|
||||
func TestPanicCase(t *testing.T) {
|
||||
|
@ -79,49 +67,3 @@ func TestErrorCase(t *testing.T) {
|
|||
assert.False(t, errors.Is(err, context.DeadlineExceeded), "not expecting DeadlineExceeded")
|
||||
assert.Equal(t, "dummy error", err.Error())
|
||||
}
|
||||
|
||||
func TestPointerErrorCase(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
// nil point of a type that implement error
|
||||
var pe *pointerError = nil
|
||||
// pass this nil pointer to error interface
|
||||
var err error = pe
|
||||
// now you get a non-nil error
|
||||
assert.False(t, err == nil, "reason this test is needed")
|
||||
|
||||
tsk := asynctask.Start(ctx, func(ctx context.Context) (interface{}, error) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
var pe *pointerError = nil
|
||||
return "Done", pe
|
||||
})
|
||||
|
||||
result, err := tsk.Wait(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result, "Done")
|
||||
}
|
||||
|
||||
func TestStructErrorCase(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
// nil point of a type that implement error
|
||||
var se structError
|
||||
// pass this nil pointer to error interface
|
||||
var err error = se
|
||||
// now you get a non-nil error
|
||||
assert.False(t, err == nil, "reason this test is needed")
|
||||
|
||||
tsk := asynctask.Start(ctx, func(ctx context.Context) (interface{}, error) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
var se structError
|
||||
return "Done", se
|
||||
})
|
||||
|
||||
result, err := tsk.Wait(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, result, "Done")
|
||||
}
|
||||
|
|
10
go.mod
10
go.mod
|
@ -1,5 +1,11 @@
|
|||
module github.com/Azure/go-asynctask
|
||||
|
||||
go 1.13
|
||||
go 1.18
|
||||
|
||||
require github.com/stretchr/testify v1.7.0
|
||||
require github.com/stretchr/testify v1.7.1
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
|
||||
)
|
||||
|
|
5
go.sum
5
go.sum
|
@ -2,10 +2,9 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
|
|
|
@ -2,75 +2,48 @@ package asynctask
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// State of a task.
|
||||
type State string
|
||||
|
||||
// StateRunning indicate task is still running.
|
||||
const StateRunning State = "Running"
|
||||
|
||||
// StateCompleted indicate task is finished.
|
||||
const StateCompleted State = "Completed"
|
||||
|
||||
// StateFailed indicate task failed.
|
||||
const StateFailed State = "Failed"
|
||||
|
||||
// StateCanceled indicate task got canceled.
|
||||
const StateCanceled State = "Canceled"
|
||||
|
||||
// IsTerminalState tells whether the task finished
|
||||
func (s State) IsTerminalState() bool {
|
||||
return s != StateRunning
|
||||
}
|
||||
|
||||
// AsyncFunc is a function interface this asyncTask accepts.
|
||||
type AsyncFunc func(context.Context) (interface{}, error)
|
||||
type AsyncFunc[T any] func(context.Context) (*T, error)
|
||||
|
||||
// ErrPanic is returned if panic cought in the task
|
||||
var ErrPanic = errors.New("panic")
|
||||
|
||||
// ErrCanceled is returned if a cancel is triggered
|
||||
var ErrCanceled = errors.New("canceled")
|
||||
|
||||
// TaskStatus is a handle to the running function.
|
||||
// Task is a handle to the running function.
|
||||
// which you can use to wait, cancel, get the result.
|
||||
type TaskStatus struct {
|
||||
type Task[T any] struct {
|
||||
state State
|
||||
result interface{}
|
||||
result *T
|
||||
err error
|
||||
cancelFunc context.CancelFunc
|
||||
waitGroup *sync.WaitGroup
|
||||
}
|
||||
|
||||
// State return state of the task.
|
||||
func (t *TaskStatus) State() State {
|
||||
func (t *Task[T]) State() State {
|
||||
return t.state
|
||||
}
|
||||
|
||||
// Cancel abort the task execution
|
||||
// !! only if the function provided handles context cancel.
|
||||
func (t *TaskStatus) Cancel() {
|
||||
func (t *Task[T]) Cancel() {
|
||||
if !t.state.IsTerminalState() {
|
||||
t.cancelFunc()
|
||||
|
||||
t.finish(StateCanceled, nil, ErrCanceled)
|
||||
var result T
|
||||
t.finish(StateCanceled, &result, ErrCanceled)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait block current thread/routine until task finished or failed.
|
||||
// context passed in can terminate the wait, through context cancellation
|
||||
// but won't terminate the task (unless it's same context)
|
||||
func (t *TaskStatus) Wait(ctx context.Context) (interface{}, error) {
|
||||
func (t *Task[T]) Wait(ctx context.Context) error {
|
||||
// return immediately if task already in terminal state.
|
||||
if t.state.IsTerminalState() {
|
||||
return t.result, t.err
|
||||
return t.err
|
||||
}
|
||||
|
||||
ch := make(chan interface{})
|
||||
|
@ -81,15 +54,15 @@ func (t *TaskStatus) Wait(ctx context.Context) (interface{}, error) {
|
|||
|
||||
select {
|
||||
case <-ch:
|
||||
return t.result, t.err
|
||||
return t.err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// WaitWithTimeout block current thread/routine until task finished or failed, or exceed the duration specified.
|
||||
// timeout only stop waiting, taks will remain running.
|
||||
func (t *TaskStatus) WaitWithTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
|
||||
func (t *Task[T]) WaitWithTimeout(ctx context.Context, timeout time.Duration) (*T, error) {
|
||||
// return immediately if task already in terminal state.
|
||||
if t.state.IsTerminalState() {
|
||||
return t.result, t.err
|
||||
|
@ -98,14 +71,44 @@ func (t *TaskStatus) WaitWithTimeout(ctx context.Context, timeout time.Duration)
|
|||
ctx, cancelFunc := context.WithTimeout(ctx, timeout)
|
||||
defer cancelFunc()
|
||||
|
||||
return t.Wait(ctx)
|
||||
return t.Result(ctx)
|
||||
}
|
||||
|
||||
func (t *Task[T]) Result(ctx context.Context) (*T, error) {
|
||||
err := t.Wait(ctx)
|
||||
if err != nil {
|
||||
var result T
|
||||
return &result, err
|
||||
}
|
||||
|
||||
return t.result, t.err
|
||||
}
|
||||
|
||||
// Start run a async function and returns you a handle which you can Wait or Cancel.
|
||||
// context passed in may impact task lifetime (from context cancellation)
|
||||
func Start[T any](ctx context.Context, task AsyncFunc[T]) *Task[T] {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
var result T
|
||||
record := &Task[T]{
|
||||
state: StateRunning,
|
||||
result: &result,
|
||||
cancelFunc: cancel,
|
||||
waitGroup: wg,
|
||||
}
|
||||
|
||||
go runAndTrackGenericTask(ctx, record, task)
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
// NewCompletedTask returns a Completed task, with result=nil, error=nil
|
||||
func NewCompletedTask() *TaskStatus {
|
||||
return &TaskStatus{
|
||||
func NewCompletedTask[T any](value *T) *Task[T] {
|
||||
return &Task[T]{
|
||||
state: StateCompleted,
|
||||
result: nil,
|
||||
result: value,
|
||||
err: nil,
|
||||
// nil cancelFunc and waitGroup should be protected with IsTerminalState()
|
||||
cancelFunc: nil,
|
||||
|
@ -113,42 +116,7 @@ func NewCompletedTask() *TaskStatus {
|
|||
}
|
||||
}
|
||||
|
||||
// Start run a async function and returns you a handle which you can Wait or Cancel.
|
||||
// context passed in may impact task lifetime (from context cancellation)
|
||||
func Start(ctx context.Context, task AsyncFunc) *TaskStatus {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
record := &TaskStatus{
|
||||
state: StateRunning,
|
||||
result: nil,
|
||||
cancelFunc: cancel,
|
||||
waitGroup: wg,
|
||||
}
|
||||
|
||||
go runAndTrackTask(ctx, record, task)
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
// isErrorReallyError do extra error check
|
||||
// - Nil Pointer to a Type (that implement error)
|
||||
// - Zero Value of a Type (that implement error)
|
||||
func isErrorReallyError(err error) bool {
|
||||
v := reflect.ValueOf(err)
|
||||
if v.Type().Kind() == reflect.Ptr &&
|
||||
v.IsNil() {
|
||||
return false
|
||||
}
|
||||
|
||||
if v.Type().Kind() == reflect.Struct &&
|
||||
v.IsZero() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func runAndTrackTask(ctx context.Context, record *TaskStatus, task func(ctx context.Context) (interface{}, error)) {
|
||||
func runAndTrackGenericTask[T any](ctx context.Context, record *Task[T], task func(ctx context.Context) (*T, error)) {
|
||||
defer record.waitGroup.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
|
@ -159,11 +127,7 @@ func runAndTrackTask(ctx context.Context, record *TaskStatus, task func(ctx cont
|
|||
|
||||
result, err := task(ctx)
|
||||
|
||||
if err == nil ||
|
||||
// incase some team use pointer typed error (implement Error() string on a pointer type)
|
||||
// which can break err check (but nil point assigned to error result to non-nil error)
|
||||
// check out TestPointerErrorCase in error_test.go
|
||||
!isErrorReallyError(err) {
|
||||
if err == nil {
|
||||
record.finish(StateCompleted, result, nil)
|
||||
return
|
||||
}
|
||||
|
@ -172,7 +136,7 @@ func runAndTrackTask(ctx context.Context, record *TaskStatus, task func(ctx cont
|
|||
record.finish(StateFailed, result, err)
|
||||
}
|
||||
|
||||
func (t *TaskStatus) finish(state State, result interface{}, err error) {
|
||||
func (t *Task[T]) finish(state State, result *T, err error) {
|
||||
// only update state and result if not yet canceled
|
||||
if !t.state.IsTerminalState() {
|
||||
t.state = state
|
|
@ -9,9 +9,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type notMatter string
|
||||
|
||||
const testContextKey notMatter = "testing"
|
||||
const testContextKey string = "testing"
|
||||
|
||||
func newTestContext(t *testing.T) context.Context {
|
||||
return context.WithValue(context.TODO(), testContextKey, t)
|
||||
|
@ -21,8 +19,8 @@ func newTestContextWithTimeout(t *testing.T, timeout time.Duration) (context.Con
|
|||
return context.WithTimeout(context.WithValue(context.TODO(), testContextKey, t), timeout)
|
||||
}
|
||||
|
||||
func getCountingTask(countTo int, sleepInterval time.Duration) asynctask.AsyncFunc {
|
||||
return func(ctx context.Context) (interface{}, error) {
|
||||
func getCountingTask(countTo int, sleepInterval time.Duration) asynctask.AsyncFunc[int] {
|
||||
return func(ctx context.Context) (*int, error) {
|
||||
t := ctx.Value(testContextKey).(*testing.T)
|
||||
|
||||
result := 0
|
||||
|
@ -33,14 +31,14 @@ func getCountingTask(countTo int, sleepInterval time.Duration) asynctask.AsyncFu
|
|||
result = i
|
||||
case <-ctx.Done():
|
||||
t.Log("work canceled")
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
return &result, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestEasyCase(t *testing.T) {
|
||||
func TestEasyGenericCase(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
|
||||
defer cancelFunc()
|
||||
|
@ -48,28 +46,26 @@ func TestEasyCase(t *testing.T) {
|
|||
t1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
|
||||
assert.Equal(t, asynctask.StateRunning, t1.State(), "Task should queued to Running")
|
||||
|
||||
rawResult, err := t1.Wait(ctx)
|
||||
rawResult, err := t1.Result(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now")
|
||||
assert.NotNil(t, rawResult)
|
||||
result := rawResult.(int)
|
||||
assert.Equal(t, result, 9)
|
||||
assert.Equal(t, *rawResult, 9)
|
||||
|
||||
// wait Again,
|
||||
start := time.Now()
|
||||
rawResult, err = t1.Wait(ctx)
|
||||
rawResult, err = t1.Result(ctx)
|
||||
elapsed := time.Since(start)
|
||||
// nothing should change
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, asynctask.StateCompleted, t1.State(), "Task should complete by now")
|
||||
assert.NotNil(t, rawResult)
|
||||
result = rawResult.(int)
|
||||
assert.Equal(t, result, 9)
|
||||
assert.Equal(t, *rawResult, 9)
|
||||
|
||||
assert.True(t, elapsed.Microseconds() < 3, "Second wait should return immediately")
|
||||
}
|
||||
|
||||
func TestCancelFunc(t *testing.T) {
|
||||
func TestCancelFuncOnGeneric(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
|
||||
defer cancelFunc()
|
||||
|
@ -80,18 +76,16 @@ func TestCancelFunc(t *testing.T) {
|
|||
time.Sleep(time.Second * 1)
|
||||
t1.Cancel()
|
||||
|
||||
rawResult, err := t1.Wait(ctx)
|
||||
_, err := t1.Result(ctx)
|
||||
assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error")
|
||||
assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state")
|
||||
assert.Nil(t, rawResult)
|
||||
|
||||
// I can cancel again, and nothing changes
|
||||
time.Sleep(time.Second * 1)
|
||||
t1.Cancel()
|
||||
rawResult, err = t1.Wait(ctx)
|
||||
_, err = t1.Result(ctx)
|
||||
assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error")
|
||||
assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state")
|
||||
assert.Nil(t, rawResult)
|
||||
|
||||
// cancel a task shouldn't cancel it's parent context.
|
||||
select {
|
||||
|
@ -102,7 +96,7 @@ func TestCancelFunc(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConsistentResultAfterCancel(t *testing.T) {
|
||||
func TestConsistentResultAfterCancelGenericTask(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
|
||||
defer cancelFunc()
|
||||
|
@ -119,24 +113,25 @@ func TestConsistentResultAfterCancel(t *testing.T) {
|
|||
assert.True(t, duration < 1*time.Millisecond, "cancel shouldn't take that long")
|
||||
|
||||
// wait til routine finish
|
||||
rawResult, err := t2.Wait(ctx)
|
||||
rawResult, err := t2.Result(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, asynctask.StateCompleted, t2.State(), "t2 should complete")
|
||||
assert.Equal(t, rawResult, 9)
|
||||
assert.Equal(t, *rawResult, 9)
|
||||
|
||||
// t1 should remain canceled and
|
||||
rawResult, err = t1.Wait(ctx)
|
||||
rawResult, err = t1.Result(ctx)
|
||||
assert.Equal(t, asynctask.ErrCanceled, err, "should return reason of error")
|
||||
assert.Equal(t, asynctask.StateCanceled, t1.State(), "Task should remain in cancel state")
|
||||
assert.Nil(t, rawResult)
|
||||
assert.Equal(t, *rawResult, 0) // default value for int
|
||||
}
|
||||
|
||||
func TestCompletedTask(t *testing.T) {
|
||||
func TestCompletedGenericTask(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
tsk := asynctask.NewCompletedTask()
|
||||
result := "something"
|
||||
tsk := asynctask.NewCompletedTask(&result)
|
||||
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should in CompletedState")
|
||||
|
||||
// nothing should happen
|
||||
|
@ -144,19 +139,19 @@ func TestCompletedTask(t *testing.T) {
|
|||
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")
|
||||
|
||||
// you get nil result and nil error
|
||||
result, err := tsk.Wait(ctx)
|
||||
resultGet, err := tsk.Result(ctx)
|
||||
assert.Equal(t, asynctask.StateCompleted, tsk.State(), "Task should still in CompletedState")
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Equal(t, *resultGet, result)
|
||||
}
|
||||
|
||||
func TestCrazyCase(t *testing.T) {
|
||||
func TestCrazyCaseGeneric(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := newTestContextWithTimeout(t, 3*time.Second)
|
||||
defer cancelFunc()
|
||||
|
||||
numOfTasks := 8000 // if you have --race switch on: limit on 8128 simultaneously alive goroutines is exceeded, dying
|
||||
tasks := map[int]*asynctask.TaskStatus{}
|
||||
tasks := map[int]*asynctask.Task[int]{}
|
||||
for i := 0; i < numOfTasks; i++ {
|
||||
tasks[i] = asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
|
||||
}
|
||||
|
@ -167,17 +162,14 @@ func TestCrazyCase(t *testing.T) {
|
|||
}
|
||||
|
||||
for i := 0; i < numOfTasks; i += 1 {
|
||||
rawResult, err := tasks[i].Wait(ctx)
|
||||
rawResult, err := tasks[i].Result(ctx)
|
||||
|
||||
if i%2 == 0 {
|
||||
assert.Equal(t, asynctask.ErrCanceled, err, "should be canceled")
|
||||
assert.Nil(t, rawResult)
|
||||
assert.Equal(t, *rawResult, 0)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rawResult)
|
||||
|
||||
result := rawResult.(int)
|
||||
assert.Equal(t, result, 9)
|
||||
assert.Equal(t, *rawResult, 9)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package asynctask
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// State of a task.
|
||||
type State string
|
||||
|
||||
// StateRunning indicate task is still running.
|
||||
const StateRunning State = "Running"
|
||||
|
||||
// StateCompleted indicate task is finished.
|
||||
const StateCompleted State = "Completed"
|
||||
|
||||
// StateFailed indicate task failed.
|
||||
const StateFailed State = "Failed"
|
||||
|
||||
// StateCanceled indicate task got canceled.
|
||||
const StateCanceled State = "Canceled"
|
||||
|
||||
// IsTerminalState tells whether the task finished
|
||||
func (s State) IsTerminalState() bool {
|
||||
return s != StateRunning
|
||||
}
|
||||
|
||||
// ErrPanic is returned if panic cought in the task
|
||||
var ErrPanic = errors.New("panic")
|
||||
|
||||
// ErrCanceled is returned if a cancel is triggered
|
||||
var ErrCanceled = errors.New("canceled")
|
10
wait_all.go
10
wait_all.go
|
@ -6,6 +6,10 @@ import (
|
|||
"sync"
|
||||
)
|
||||
|
||||
type Waitable interface {
|
||||
Wait(context.Context) error
|
||||
}
|
||||
|
||||
// WaitAllOptions defines options for WaitAll function
|
||||
type WaitAllOptions struct {
|
||||
// FailFast set to true will indicate WaitAll to return on first error it sees.
|
||||
|
@ -14,7 +18,7 @@ type WaitAllOptions struct {
|
|||
|
||||
// WaitAll block current thread til all task finished.
|
||||
// first error from any tasks passed in will be returned.
|
||||
func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...*TaskStatus) error {
|
||||
func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...Waitable) error {
|
||||
tasksCount := len(tasks)
|
||||
|
||||
mutex := sync.Mutex{}
|
||||
|
@ -65,8 +69,8 @@ func WaitAll(ctx context.Context, options *WaitAllOptions, tasks ...*TaskStatus)
|
|||
return nil
|
||||
}
|
||||
|
||||
func waitOne(ctx context.Context, tsk *TaskStatus, errorCh chan<- error, errorChClosed *bool, mutex *sync.Mutex) {
|
||||
_, err := tsk.Wait(ctx)
|
||||
func waitOne(ctx context.Context, tsk Waitable, errorCh chan<- error, errorChClosed *bool, mutex *sync.Mutex) {
|
||||
err := tsk.Wait(ctx)
|
||||
|
||||
// why mutex?
|
||||
// if all tasks start using same context (unittest is good example)
|
||||
|
|
|
@ -18,7 +18,8 @@ func TestWaitAll(t *testing.T) {
|
|||
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
|
||||
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond))
|
||||
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond))
|
||||
completedTsk := asynctask.NewCompletedTask()
|
||||
result := "something"
|
||||
completedTsk := asynctask.NewCompletedTask(&result)
|
||||
|
||||
start := time.Now()
|
||||
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk1, countingTsk2, countingTsk3, completedTsk)
|
||||
|
@ -36,7 +37,8 @@ func TestWaitAllFailFastCase(t *testing.T) {
|
|||
countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
|
||||
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond))
|
||||
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond))
|
||||
completedTsk := asynctask.NewCompletedTask()
|
||||
result := "something"
|
||||
completedTsk := asynctask.NewCompletedTask(&result)
|
||||
|
||||
start := time.Now()
|
||||
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: true}, countingTsk, errorTsk, panicTsk, completedTsk)
|
||||
|
@ -61,7 +63,8 @@ func TestWaitAllErrorCase(t *testing.T) {
|
|||
countingTsk := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
|
||||
errorTsk := asynctask.Start(ctx, getErrorTask("expected error", 10*time.Millisecond))
|
||||
panicTsk := asynctask.Start(ctx, getPanicTask(20*time.Millisecond))
|
||||
completedTsk := asynctask.NewCompletedTask()
|
||||
result := "something"
|
||||
completedTsk := asynctask.NewCompletedTask(&result)
|
||||
|
||||
start := time.Now()
|
||||
err := asynctask.WaitAll(ctx, &asynctask.WaitAllOptions{FailFast: false}, countingTsk, errorTsk, panicTsk, completedTsk)
|
||||
|
@ -86,7 +89,8 @@ func TestWaitAllCanceled(t *testing.T) {
|
|||
countingTsk1 := asynctask.Start(ctx, getCountingTask(10, 200*time.Millisecond))
|
||||
countingTsk2 := asynctask.Start(ctx, getCountingTask(10, 20*time.Millisecond))
|
||||
countingTsk3 := asynctask.Start(ctx, getCountingTask(10, 2*time.Millisecond))
|
||||
completedTsk := asynctask.NewCompletedTask()
|
||||
result := "something"
|
||||
completedTsk := asynctask.NewCompletedTask(&result)
|
||||
|
||||
waitCtx, cancelFunc1 := context.WithTimeout(ctx, 5*time.Millisecond)
|
||||
defer cancelFunc1()
|
||||
|
|
Загрузка…
Ссылка в новой задаче