internal/pipe: new package for handling command pipelines
This commit is contained in:
Родитель
6aa6890117
Коммит
1458ae5f8b
5
go.mod
5
go.mod
|
@ -6,6 +6,7 @@ require (
|
|||
github.com/cli/safeexec v1.0.0
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/stretchr/testify v1.4.0
|
||||
gopkg.in/yaml.v2 v2.2.7 // indirect
|
||||
github.com/stretchr/testify v1.7.0
|
||||
go.uber.org/goleak v1.1.12
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
|
||||
)
|
||||
|
|
50
go.sum
50
go.sum
|
@ -1,19 +1,53 @@
|
|||
github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
|
||||
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA=
|
||||
go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
|
||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA=
|
||||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
|
||||
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
|
|
@ -0,0 +1,223 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// commandStage is a pipeline `Stage` based on running an external
|
||||
// command and piping the data through its stdin and stdout.
|
||||
type commandStage struct {
|
||||
name string
|
||||
stdin io.Closer
|
||||
cmd *exec.Cmd
|
||||
done chan struct{}
|
||||
wg errgroup.Group
|
||||
stderr bytes.Buffer
|
||||
|
||||
// If the context expired and we attempted to kill the command,
|
||||
// `ctx.Err()` is stored here.
|
||||
ctxErr atomic.Value
|
||||
}
|
||||
|
||||
// Command returns a pipeline `Stage` based on the specified external
|
||||
// `command`, run with the given command-line `args`. Its stdin and
|
||||
// stdout are handled as usual, and its stderr is collected and
|
||||
// included in any `*exec.ExitError` that the command might emit.
|
||||
func Command(command string, args ...string) Stage {
|
||||
if len(command) == 0 {
|
||||
panic("attempt to create command with empty command")
|
||||
}
|
||||
|
||||
cmd := exec.Command(command, args...)
|
||||
return CommandStage(command, cmd)
|
||||
}
|
||||
|
||||
// Command returns a pipeline `Stage` with the name `name`, based on
|
||||
// the specified `cmd`. Its stdin and stdout are handled as usual, and
|
||||
// its stderr is collected and included in any `*exec.ExitError` that
|
||||
// the command might emit.
|
||||
func CommandStage(name string, cmd *exec.Cmd) Stage {
|
||||
return &commandStage{
|
||||
name: name,
|
||||
cmd: cmd,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *commandStage) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
func (s *commandStage) Start(
|
||||
ctx context.Context, env Env, stdin io.ReadCloser,
|
||||
) (io.ReadCloser, error) {
|
||||
if s.cmd.Dir == "" {
|
||||
s.cmd.Dir = env.Dir
|
||||
}
|
||||
|
||||
if stdin != nil {
|
||||
s.cmd.Stdin = stdin
|
||||
// Also keep a copy so that we can close it when the command exits:
|
||||
s.stdin = stdin
|
||||
}
|
||||
|
||||
stdout, err := s.cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the caller hasn't arranged otherwise, read the command's
|
||||
// standard error into our `stderr` field:
|
||||
if s.cmd.Stderr == nil {
|
||||
// We can't just set `s.cmd.Stderr = &s.stderr`, because if we
|
||||
// do then `s.cmd.Wait()` doesn't wait to be sure that all
|
||||
// error output has been captured. By doing this ourselves, we
|
||||
// can be sure.
|
||||
p, err := s.cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.wg.Go(func() error {
|
||||
_, err := io.Copy(&s.stderr, p)
|
||||
// We don't consider `ErrClosed` an error (FIXME: is this
|
||||
// correct?):
|
||||
if err != nil && !errors.Is(err, os.ErrClosed) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Put the command in its own process group:
|
||||
if s.cmd.SysProcAttr == nil {
|
||||
s.cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
}
|
||||
s.cmd.SysProcAttr.Setpgid = true
|
||||
|
||||
if err := s.cmd.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Arrange for the process to be killed (gently) if the context
|
||||
// expires before the command exits normally:
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.kill(ctx.Err())
|
||||
case <-s.done:
|
||||
// Process already done; no need to kill anything.
|
||||
}
|
||||
}()
|
||||
|
||||
return stdout, nil
|
||||
}
|
||||
|
||||
// kill is called to kill the process if the context expires. `err` is
|
||||
// the corresponding value of `Context.Err()`.
|
||||
func (s *commandStage) kill(err error) {
|
||||
// I believe that the calls to `syscall.Kill()` in this method are
|
||||
// racy. It could be that s.cmd.Wait() succeeds immediately before
|
||||
// this call, in which case the process group wouldn't exist
|
||||
// anymore. But I don't see any way to avoid this without
|
||||
// duplicating a lot of code from `exec.Cmd`. (`os.Cmd.Kill()` and
|
||||
// `os.Cmd.Signal()` appear to be race-free, but only because they
|
||||
// use internal synchronization. But those methods only kill the
|
||||
// process, not the process group, so they are not suitable here.
|
||||
|
||||
// We started the process with PGID == PID:
|
||||
pid := s.cmd.Process.Pid
|
||||
select {
|
||||
case <-s.done:
|
||||
// Process has ended; no need to kill it again.
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Record the `ctx.Err()`, which will be used as the error result
|
||||
// for this stage.
|
||||
s.ctxErr.Store(err)
|
||||
|
||||
// First try to kill using a relatively gentle signal so that
|
||||
// the processes have a chance to clean up after themselves:
|
||||
_ = syscall.Kill(-pid, syscall.SIGTERM)
|
||||
|
||||
// Well-behaved processes should commit suicide after the above,
|
||||
// but if they don't exit within 2s, murder the whole lot of them:
|
||||
go func() {
|
||||
// Use an explicit `time.Timer` rather than `time.After()` so
|
||||
// that we can stop it (freeing resources) promptly if the
|
||||
// command exits before the timer triggers.
|
||||
timer := time.NewTimer(2 * time.Second)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-s.done:
|
||||
// Process has ended; no need to kill it again.
|
||||
case <-timer.C:
|
||||
_ = syscall.Kill(-pid, syscall.SIGKILL)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// filterCmdError interprets `err`, which was returned by `Cmd.Wait()`
|
||||
// (possibly `nil`), possibly modifying it or ignoring it. It returns
|
||||
// the error that should actually be returned to the caller (possibly
|
||||
// `nil`).
|
||||
func (s *commandStage) filterCmdError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
eErr, ok := err.(*exec.ExitError)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
|
||||
ctxErr, ok := s.ctxErr.Load().(error)
|
||||
if ok {
|
||||
// If the process looks like it was killed by us, substitute
|
||||
// `ctxErr` for the process's own exit error.
|
||||
ps, ok := eErr.ProcessState.Sys().(syscall.WaitStatus)
|
||||
if ok && ps.Signaled() &&
|
||||
(ps.Signal() == syscall.SIGTERM || ps.Signal() == syscall.SIGKILL) {
|
||||
return ctxErr
|
||||
}
|
||||
}
|
||||
|
||||
eErr.Stderr = s.stderr.Bytes()
|
||||
return eErr
|
||||
}
|
||||
|
||||
func (s *commandStage) Wait() error {
|
||||
defer close(s.done)
|
||||
|
||||
// Make sure that any stderr is copied before `s.cmd.Wait()`
|
||||
// closes the read end of the pipe:
|
||||
wErr := s.wg.Wait()
|
||||
|
||||
err := s.cmd.Wait()
|
||||
err = s.filterCmdError(err)
|
||||
|
||||
if err == nil && wErr != nil {
|
||||
err = wErr
|
||||
}
|
||||
|
||||
if s.stdin != nil {
|
||||
cErr := s.stdin.Close()
|
||||
if cErr != nil && err == nil {
|
||||
return cErr
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// ErrorFilter is a function that can filter errors from
|
||||
// `Stage.Wait()`. The original error (possibly nil) is passed in as
|
||||
// an argument, and whatever the function returns is the error
|
||||
// (possibly nil) that is actually emitted.
|
||||
type ErrorFilter func(err error) error
|
||||
|
||||
func FilterError(s Stage, filter ErrorFilter) Stage {
|
||||
return efStage{Stage: s, filter: filter}
|
||||
}
|
||||
|
||||
type efStage struct {
|
||||
Stage
|
||||
filter ErrorFilter
|
||||
}
|
||||
|
||||
func (s efStage) Wait() error {
|
||||
return s.filter(s.Stage.Wait())
|
||||
}
|
||||
|
||||
// ErrorMatcher decides whether its argument matches some class of
|
||||
// errors (e.g., errors that we want to ignore). The function will
|
||||
// only be invoked for non-nil errors.
|
||||
type ErrorMatcher func(err error) bool
|
||||
|
||||
// IgnoreError creates a stage that acts like `s` except that it
|
||||
// ignores any errors that are matched by `em`. Use like
|
||||
//
|
||||
// p.Add(pipe.IgnoreError(
|
||||
// someStage,
|
||||
// func(err error) bool {
|
||||
// var myError *MyErrorType
|
||||
// return errors.As(err, &myError) && myError.foo == 42
|
||||
// },
|
||||
// )
|
||||
//
|
||||
// The second argument can also be one of the `ErrorMatcher`s that are
|
||||
// provided by this package (e.g., `IsError(target)`,
|
||||
// IsSignal(signal), `IsSIGPIPE`, `IsEPIPE`, `IsPipeError`), or one of
|
||||
// the functions from the standard library that has the same signature
|
||||
// (e.g., `os.IsTimeout`), or some combination of these (e.g.,
|
||||
// `AnyError(IsSIGPIPE, os.IsTimeout)`).
|
||||
func IgnoreError(s Stage, em ErrorMatcher) Stage {
|
||||
return FilterError(s,
|
||||
func(err error) error {
|
||||
if err == nil || em(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// AnyError returns an `ErrorMatcher` that returns true for an error
|
||||
// that matches any of the `ems`.
|
||||
func AnyError(ems ...ErrorMatcher) ErrorMatcher {
|
||||
return func(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
for _, em := range ems {
|
||||
if em(err) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsError returns an ErrorIdentifier for the specified target error,
|
||||
// matched using `errors.Is()`. Use like
|
||||
//
|
||||
// p.Add(pipe.IgnoreError(someStage, IsError(io.EOF)))
|
||||
func IsError(target error) ErrorMatcher {
|
||||
return func(err error) bool {
|
||||
return errors.Is(err, target)
|
||||
}
|
||||
}
|
||||
|
||||
// IsSIGPIPE returns an `ErrorMatcher` that matches `*exec.ExitError`s
|
||||
// that were caused by the specified signal. The match for
|
||||
// `*exec.ExitError`s uses `errors.As()`.
|
||||
func IsSignal(signal syscall.Signal) ErrorMatcher {
|
||||
return func(err error) bool {
|
||||
var eErr *exec.ExitError
|
||||
|
||||
if !errors.As(err, &eErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
status, ok := eErr.Sys().(syscall.WaitStatus)
|
||||
return ok && status.Signaled() && status.Signal() == signal
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// IsSIGPIPE is an `ErrorMatcher` that matches `*exec.ExitError`s
|
||||
// that were caused by SIGPIPE. The match for `*exec.ExitError`s
|
||||
// uses `errors.As()`. Use like
|
||||
//
|
||||
// p.Add(IgnoreError(someStage, IsSIGPIPE))
|
||||
IsSIGPIPE = IsSignal(syscall.SIGPIPE)
|
||||
|
||||
// IsEPIPE is an `ErrorMatcher` that matches `syscall.EPIPE` using
|
||||
// `errors.Is()`. Use like
|
||||
//
|
||||
// p.Add(IgnoreError(someStage, IsEPIPE))
|
||||
IsEPIPE = IsError(syscall.EPIPE)
|
||||
|
||||
// IsErrClosedPipe is an `ErrorMatcher` that matches
|
||||
// `io.ErrClosedPipe` using `errors.Is()`. (`io.ErrClosedPipe` is
|
||||
// the error that results from writing to a closed
|
||||
// `*io.PipeWriter`.) Use like
|
||||
//
|
||||
// p.Add(IgnoreError(someStage, IsErrClosedPipe))
|
||||
IsErrClosedPipe = IsError(io.ErrClosedPipe)
|
||||
|
||||
// IsPipeError is an `ErrorMatcher` that matches a few different
|
||||
// errors that typically result if a stage writes to a subsequent
|
||||
// stage that has stopped reading from its stdin. Use like
|
||||
//
|
||||
// p.Add(IgnoreError(someStage, IsPipeError))
|
||||
IsPipeError = AnyError(IsSIGPIPE, IsEPIPE, IsErrClosedPipe)
|
||||
)
|
|
@ -0,0 +1,66 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// StageFunc is a function that can be used to power a `goStage`. It
|
||||
// should read its input from `stdin` and write its output to
|
||||
// `stdout`. `stdin` and `stdout` will be closed automatically (if
|
||||
// necessary) once the function returns.
|
||||
//
|
||||
// Neither `stdin` nor `stdout` are necessarily buffered. If the
|
||||
// `StageFunc` requires buffering, it needs to arrange that itself.
|
||||
//
|
||||
// A `StageFunc` is run in a separate goroutine, so it must be careful
|
||||
// to synchronize any data access aside from reading and writing.
|
||||
type StageFunc func(ctx context.Context, env Env, stdin io.Reader, stdout io.Writer) error
|
||||
|
||||
// Function returns a pipeline `Stage` that will run a `StageFunc` in
|
||||
// a separate goroutine to process the data. See `StageFunc` for more
|
||||
// information.
|
||||
func Function(name string, f StageFunc) Stage {
|
||||
return &goStage{
|
||||
name: name,
|
||||
f: f,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// goStage is a `Stage` that does its work by running an arbitrary
|
||||
// `stageFunc` in a goroutine.
|
||||
type goStage struct {
|
||||
name string
|
||||
f StageFunc
|
||||
done chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *goStage) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) {
|
||||
r, w := io.Pipe()
|
||||
go func() {
|
||||
s.err = s.f(ctx, env, stdin, w)
|
||||
if err := w.Close(); err != nil && s.err == nil {
|
||||
s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err)
|
||||
}
|
||||
if stdin != nil {
|
||||
if err := stdin.Close(); err != nil && s.err == nil {
|
||||
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
|
||||
}
|
||||
}
|
||||
close(s.done)
|
||||
}()
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (s *goStage) Wait() error {
|
||||
<-s.done
|
||||
return s.err
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// ioCopier is a stage that copies its stdin to a specified
|
||||
// `io.Writer`. It generates no stdout itself.
|
||||
type ioCopier struct {
|
||||
w io.WriteCloser
|
||||
done chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func newIOCopier(w io.WriteCloser) *ioCopier {
|
||||
return &ioCopier{
|
||||
w: w,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ioCopier) Name() string {
|
||||
return "ioCopier"
|
||||
}
|
||||
|
||||
// This method always returns `nil, nil`.
|
||||
func (s *ioCopier) Start(ctx context.Context, _ Env, r io.ReadCloser) (io.ReadCloser, error) {
|
||||
go func() {
|
||||
_, err := io.Copy(s.w, r)
|
||||
// We don't consider `ErrClosed` an error (FIXME: is this
|
||||
// correct?):
|
||||
if err != nil && !errors.Is(err, os.ErrClosed) {
|
||||
s.err = err
|
||||
}
|
||||
if err := r.Close(); err != nil && s.err == nil {
|
||||
s.err = err
|
||||
}
|
||||
if err := s.w.Close(); err != nil && s.err == nil {
|
||||
s.err = err
|
||||
}
|
||||
close(s.done)
|
||||
}()
|
||||
|
||||
// FIXME: if `s.w.Write()` is blocking (e.g., because there is a
|
||||
// downstream process that is not reading from the other side),
|
||||
// there's no way to terminate the copy when the context expires.
|
||||
// This is not too bad, because the `io.Copy()` call will exit by
|
||||
// itself when its input is closed.
|
||||
//
|
||||
// We could, however, be smarter about exiting more quickly if the
|
||||
// context expires but `s.w.Write()` is not blocking.
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *ioCopier) Wait() error {
|
||||
<-s.done
|
||||
return s.err
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
// LinewiseStageFunc is a function that can be embedded in a
|
||||
// `goStage`. It is called once per line in the input (where "line"
|
||||
// can be defined via any `bufio.Scanner`). It should process the line
|
||||
// and may write whatever it likes to `stdout`, which is a buffered
|
||||
// writer whose contents are forwarded to the input of the next stage
|
||||
// of the pipeline. The function needn't write one line of output per
|
||||
// line of input.
|
||||
//
|
||||
// The function mustn't retain copies of `line`, since it may be
|
||||
// overwritten every time the function is called.
|
||||
//
|
||||
// The function needn't flush or close `stdout` (this will be done
|
||||
// automatically when all of the input has been processed).
|
||||
//
|
||||
// If there is an error parsing the input into lines, or if this
|
||||
// function returns an error, then the whole pipeline will be aborted
|
||||
// with that error. However, if the function returns the special error
|
||||
// `pipe.FinishEarly`, the stage will stop processing immediately with
|
||||
// a `nil` error value.
|
||||
//
|
||||
// The function will be called in a separate goroutine, so it must be
|
||||
// careful to synchronize any data access aside from writing to
|
||||
// `stdout`.
|
||||
type LinewiseStageFunc func(
|
||||
ctx context.Context, env Env, line []byte, stdout *bufio.Writer,
|
||||
) error
|
||||
|
||||
// LinewiseFunction returns a function-based `Stage`. The input will
|
||||
// be split into LF-terminated lines and passed to the function one
|
||||
// line at a time (without the LF). The function may emit output to
|
||||
// its `stdout` argument. See the definition of `LinewiseStageFunc`
|
||||
// for more information.
|
||||
//
|
||||
// Note that the stage will emit an error if any line (including its
|
||||
// end-of-line terminator) exceeds 64 kiB in length. If this is too
|
||||
// short, use `ScannerFunction()` directly with your own
|
||||
// `NewScannerFunc` as argument, or use `Function()` directly with
|
||||
// your own `StageFunc`.
|
||||
func LinewiseFunction(name string, f LinewiseStageFunc) Stage {
|
||||
return ScannerFunction(
|
||||
name,
|
||||
func(r io.Reader) (Scanner, error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Split based on strict LF (we don't accept CRLF):
|
||||
scanner.Split(ScanLFTerminatedLines)
|
||||
return scanner, nil
|
||||
},
|
||||
f,
|
||||
)
|
||||
}
|
||||
|
||||
// ScanLFTerminatedLines is a `bufio.SplitFunc` that splits its input
|
||||
// into lines at LF characters (not treating CR specially).
|
||||
func ScanLFTerminatedLines(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := bytes.IndexByte(data, '\n'); i != -1 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
}
|
|
@ -0,0 +1,211 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Env represents the environment that a pipeline stage should run in.
|
||||
// It is passed to `Stage.Start()`.
|
||||
type Env struct {
|
||||
// The directory in which external commands should be executed by
|
||||
// default.
|
||||
Dir string
|
||||
}
|
||||
|
||||
// Pipeline represents a Unix-like pipe that can include multiple
|
||||
// stages, including external processes but also and stages written in
|
||||
// Go.
|
||||
type Pipeline struct {
|
||||
env Env
|
||||
|
||||
stdin io.Reader
|
||||
stdout io.WriteCloser
|
||||
stages []Stage
|
||||
cancel func()
|
||||
|
||||
// Atomically written and read value, nonzero if the pipeline has
|
||||
// been started. This is only used for lifecycle sanity checks but
|
||||
// does not guarantee that clients are using the class correctly.
|
||||
started uint32
|
||||
}
|
||||
|
||||
type nopWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (w nopWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewPipeline returns a Pipeline struct with all of the `options`
|
||||
// applied.
|
||||
func New(options ...Option) *Pipeline {
|
||||
p := &Pipeline{}
|
||||
|
||||
for _, option := range options {
|
||||
option(p)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Option is a type alias for Pipeline functional options.
|
||||
type Option func(*Pipeline)
|
||||
|
||||
// WithDir sets the default directory for running external commands.
|
||||
func WithDir(dir string) Option {
|
||||
return func(p *Pipeline) {
|
||||
p.env.Dir = dir
|
||||
}
|
||||
}
|
||||
|
||||
// WithStdin assigns stdin to the first command in the pipeline.
|
||||
func WithStdin(stdin io.Reader) Option {
|
||||
return func(p *Pipeline) {
|
||||
p.stdin = stdin
|
||||
}
|
||||
}
|
||||
|
||||
// WithStdout assigns stdout to the last command in the pipeline.
|
||||
func WithStdout(stdout io.Writer) Option {
|
||||
return func(p *Pipeline) {
|
||||
p.stdout = nopWriteCloser{stdout}
|
||||
}
|
||||
}
|
||||
|
||||
// WithStdoutCloser assigns stdout to the last command in the
|
||||
// pipeline, and closes stdout when it's done.
|
||||
func WithStdoutCloser(stdout io.WriteCloser) Option {
|
||||
return func(p *Pipeline) {
|
||||
p.stdout = stdout
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pipeline) hasStarted() bool {
|
||||
return atomic.LoadUint32(&p.started) != 0
|
||||
}
|
||||
|
||||
// Add appends one or more stages to the pipeline.
|
||||
func (p *Pipeline) Add(stages ...Stage) {
|
||||
if p.hasStarted() {
|
||||
panic("attempt to modify a pipeline that has already started")
|
||||
}
|
||||
|
||||
p.stages = append(p.stages, stages...)
|
||||
}
|
||||
|
||||
// AddWithIgnoredError appends one or more stages that are ignoring
|
||||
// the passed in error to the pipeline.
|
||||
func (p *Pipeline) AddWithIgnoredError(em ErrorMatcher, stages ...Stage) {
|
||||
if p.hasStarted() {
|
||||
panic("attempt to modify a pipeline that has already started")
|
||||
}
|
||||
|
||||
for _, stage := range stages {
|
||||
p.stages = append(p.stages, IgnoreError(stage, em))
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the commands in the pipeline. If `Start()` exits
|
||||
// without an error, `Wait()` must also be called, to allow all
|
||||
// resources to be freed.
|
||||
func (p *Pipeline) Start(ctx context.Context) error {
|
||||
if p.hasStarted() {
|
||||
panic("attempt to start a pipeline that has already started")
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&p.started, 1)
|
||||
ctx, p.cancel = context.WithCancel(ctx)
|
||||
|
||||
var nextStdin io.ReadCloser
|
||||
if p.stdin != nil {
|
||||
// We don't want the first stage to actually close this, and
|
||||
// it's not even an `io.ReadCloser`, so fake it:
|
||||
nextStdin = ioutil.NopCloser(p.stdin)
|
||||
}
|
||||
|
||||
for i, s := range p.stages {
|
||||
var err error
|
||||
stdout, err := s.Start(ctx, p.env, nextStdin)
|
||||
if err != nil {
|
||||
// Close the pipe that the previous stage was writing to.
|
||||
// That should cause it to exit even if it's not minding
|
||||
// its context.
|
||||
if nextStdin != nil {
|
||||
_ = nextStdin.Close()
|
||||
}
|
||||
|
||||
// Kill and wait for any stages that have been started
|
||||
// already to finish:
|
||||
p.cancel()
|
||||
for _, s := range p.stages[:i] {
|
||||
_ = s.Wait()
|
||||
}
|
||||
return fmt.Errorf("starting pipeline stage %q: %w", s.Name(), err)
|
||||
}
|
||||
nextStdin = stdout
|
||||
}
|
||||
|
||||
// If the pipeline was configured with a `stdout`, add a synthetic
|
||||
// stage to copy the last stage's stdout to that writer:
|
||||
if p.stdout != nil {
|
||||
c := newIOCopier(p.stdout)
|
||||
p.stages = append(p.stages, c)
|
||||
// `ioCopier.Start()` never fails:
|
||||
_, _ = c.Start(ctx, p.env, nextStdin)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pipeline) Output(ctx context.Context) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
p.stdout = nopWriteCloser{&buf}
|
||||
err := p.Run(ctx)
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
// Wait waits for each stage in the pipeline to exit.
|
||||
func (p *Pipeline) Wait() error {
|
||||
if !p.hasStarted() {
|
||||
panic("unable to wait on a pipeline that has not started")
|
||||
}
|
||||
|
||||
// Make sure that all of the cleanup eventually happens:
|
||||
defer p.cancel()
|
||||
|
||||
var earliestStageErr error
|
||||
var earliestFailedStage Stage
|
||||
|
||||
for i := len(p.stages) - 1; i >= 0; i-- {
|
||||
s := p.stages[i]
|
||||
err := s.Wait()
|
||||
if err != nil {
|
||||
// Overwrite any existing values here so that we end up
|
||||
// retaining the last error that we see; i.e., the error
|
||||
// that happened earliest in the pipeline.
|
||||
earliestStageErr = err
|
||||
earliestFailedStage = s
|
||||
}
|
||||
}
|
||||
|
||||
if earliestStageErr != nil {
|
||||
return fmt.Errorf("%s: %w", earliestFailedStage.Name(), earliestStageErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run starts and waits for the commands in the pipeline.
|
||||
func (p *Pipeline) Run(ctx context.Context) error {
|
||||
if err := p.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.Wait()
|
||||
}
|
|
@ -0,0 +1,664 @@
|
|||
package pipe_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"github.com/github/git-sizer/internal/pipe"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Check whether this package's test suite leaks any goroutines:
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestPipelineFirstStageFailsToStart(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
startErr := errors.New("foo")
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
ErrorStartingStage{startErr},
|
||||
ErrorStartingStage{errors.New("this error should never happen")},
|
||||
)
|
||||
assert.ErrorIs(t, p.Run(ctx), startErr)
|
||||
}
|
||||
|
||||
func TestPipelineSecondStageFailsToStart(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
startErr := errors.New("foo")
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
seqFunction(20000),
|
||||
ErrorStartingStage{startErr},
|
||||
)
|
||||
assert.ErrorIs(t, p.Run(ctx), startErr)
|
||||
}
|
||||
|
||||
func TestPipelineSingleCommandOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(pipe.Command("echo", "hello world"))
|
||||
out, err := p.Output(ctx)
|
||||
if assert.NoError(t, err) {
|
||||
assert.EqualValues(t, "hello world\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineSingleCommandWithStdout(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
stdout := &bytes.Buffer{}
|
||||
|
||||
p := pipe.New(pipe.WithStdout(stdout))
|
||||
p.Add(pipe.Command("echo", "hello world"))
|
||||
if assert.NoError(t, p.Run(ctx)) {
|
||||
assert.Equal(t, "hello world\n", stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNontrivialPipeline(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Command("echo", "hello world"),
|
||||
pipe.Command("sed", "s/hello/goodbye/"),
|
||||
)
|
||||
out, err := p.Output(ctx)
|
||||
if assert.NoError(t, err) {
|
||||
assert.EqualValues(t, "goodbye world\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineReadFromSlowly(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
r, w := io.Pipe()
|
||||
|
||||
var buf []byte
|
||||
readErr := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
var err error
|
||||
buf, err = ioutil.ReadAll(r)
|
||||
readErr <- err
|
||||
}()
|
||||
|
||||
p := pipe.New(pipe.WithStdout(w))
|
||||
p.Add(pipe.Command("echo", "hello world"))
|
||||
assert.NoError(t, p.Run(ctx))
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// It's not super-intuitive, but `w` has to be closed here so that
|
||||
// the `ioutil.ReadAll()` call above knows that it's done:
|
||||
_ = w.Close()
|
||||
|
||||
assert.NoError(t, <-readErr)
|
||||
assert.Equal(t, "hello world\n", string(buf))
|
||||
}
|
||||
|
||||
func TestPipelineReadFromSlowly2(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
r, w := io.Pipe()
|
||||
|
||||
var buf []byte
|
||||
readErr := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
for {
|
||||
var c [1]byte
|
||||
_, err := r.Read(c[:])
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
readErr <- nil
|
||||
return
|
||||
}
|
||||
readErr <- err
|
||||
return
|
||||
}
|
||||
buf = append(buf, c[0])
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
p := pipe.New(pipe.WithStdout(w))
|
||||
p.Add(pipe.Command("seq", "100"))
|
||||
assert.NoError(t, p.Run(ctx))
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
// It's not super-intuitive, but `w` has to be closed here so that
|
||||
// the `ioutil.ReadAll()` call above knows that it's done:
|
||||
_ = w.Close()
|
||||
|
||||
assert.NoError(t, <-readErr)
|
||||
assert.Equal(t, 292, len(buf))
|
||||
}
|
||||
|
||||
func TestPipelineTwoCommandsPiping(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(pipe.Command("echo", "hello world"))
|
||||
assert.Panics(t, func() { p.Add(pipe.Command("")) })
|
||||
out, err := p.Output(ctx)
|
||||
if assert.NoError(t, err) {
|
||||
assert.EqualValues(t, "hello world\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
wdir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
dir, err := ioutil.TempDir(wdir, "pipeline-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
p := pipe.New(pipe.WithDir(dir))
|
||||
p.Add(pipe.Command("pwd"))
|
||||
|
||||
out, err := p.Output(ctx)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, dir, strings.TrimSuffix(string(out), "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineExit(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Command("false"),
|
||||
pipe.Command("true"),
|
||||
)
|
||||
assert.EqualError(t, p.Run(ctx), "false: exit status 1")
|
||||
}
|
||||
|
||||
func TestPipelineStderr(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
dir, err := ioutil.TempDir("", "pipeline-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
p := pipe.New(pipe.WithDir(dir))
|
||||
p.Add(pipe.Command("ls", "doesnotexist"))
|
||||
|
||||
_, err = p.Output(ctx)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "ls: exit status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipelineInterrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
stdout := &bytes.Buffer{}
|
||||
|
||||
p := pipe.New(pipe.WithStdout(stdout))
|
||||
p.Add(pipe.Command("sleep", "10"))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := p.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = p.Wait()
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestPipelineCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
stdout := &bytes.Buffer{}
|
||||
|
||||
p := pipe.New(pipe.WithStdout(stdout))
|
||||
p.Add(pipe.Command("sleep", "10"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
err := p.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
cancel()
|
||||
|
||||
err = p.Wait()
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
// Verify the correct error if a command in the pipeline exits before
|
||||
// reading all of its predecessor's output. Note that the amount of
|
||||
// unread output in this case *does fit* within the OS-level pipe
|
||||
// buffer.
|
||||
func TestLittleEPIPE(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Command("sh", "-c", "sleep 1; echo foo"),
|
||||
pipe.Command("true"),
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
err := p.Run(ctx)
|
||||
assert.EqualError(t, err, "sh: signal: broken pipe")
|
||||
}
|
||||
|
||||
// Verify the correct error if one command in the pipeline exits
|
||||
// before reading all of its predecessor's output. Note that the
|
||||
// amount of unread output in this case *does not fit* within the
|
||||
// OS-level pipe buffer.
|
||||
func TestBigEPIPE(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Command("seq", "100000"),
|
||||
pipe.Command("true"),
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
err := p.Run(ctx)
|
||||
assert.EqualError(t, err, "seq: signal: broken pipe")
|
||||
}
|
||||
|
||||
// Verify the correct error if one command in the pipeline exits
|
||||
// before reading all of its predecessor's output. Note that the
|
||||
// amount of unread output in this case *does not fit* within the
|
||||
// OS-level pipe buffer.
|
||||
func TestIgnoredSIGPIPE(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.IgnoreError(pipe.Command("seq", "100000"), pipe.IsSIGPIPE),
|
||||
pipe.Command("echo", "foo"),
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
out, err := p.Output(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "foo\n", out)
|
||||
}
|
||||
|
||||
func TestFunction(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Print("hello world"),
|
||||
pipe.Function(
|
||||
"farewell",
|
||||
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
|
||||
buf, err := ioutil.ReadAll(stdin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if string(buf) != "hello world" {
|
||||
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
|
||||
}
|
||||
_, err = stdout.Write([]byte("goodbye, cruel world"))
|
||||
return err
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
out, err := p.Output(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "goodbye, cruel world", out)
|
||||
}
|
||||
|
||||
func TestPipelineWithFunction(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Command("echo", "-n", "hello world"),
|
||||
pipe.Function(
|
||||
"farewell",
|
||||
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
|
||||
buf, err := ioutil.ReadAll(stdin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if string(buf) != "hello world" {
|
||||
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
|
||||
}
|
||||
_, err = stdout.Write([]byte("goodbye, cruel world"))
|
||||
return err
|
||||
},
|
||||
),
|
||||
pipe.Command("tr", "a-z", "A-Z"),
|
||||
)
|
||||
|
||||
out, err := p.Output(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "GOODBYE, CRUEL WORLD", out)
|
||||
}
|
||||
|
||||
type ErrorStartingStage struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (s ErrorStartingStage) Name() string {
|
||||
return "errorStartingStage"
|
||||
}
|
||||
|
||||
func (s ErrorStartingStage) Start(
|
||||
ctx context.Context, env pipe.Env, stdin io.ReadCloser,
|
||||
) (io.ReadCloser, error) {
|
||||
return ioutil.NopCloser(&bytes.Buffer{}), s.err
|
||||
}
|
||||
|
||||
func (s ErrorStartingStage) Wait() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func seqFunction(n int) pipe.Stage {
|
||||
return pipe.Function(
|
||||
"seq",
|
||||
func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error {
|
||||
for i := 1; i <= n; i++ {
|
||||
_, err := fmt.Fprintf(stdout, "%d\n", i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestPipelineWithLinewiseFunction(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
p := pipe.New()
|
||||
// Print the numbers from 1 to 20 (generated from scratch):
|
||||
p.Add(
|
||||
seqFunction(20),
|
||||
// Discard all but the multiples of 5, and emit the results
|
||||
// separated by spaces on one line:
|
||||
pipe.LinewiseFunction(
|
||||
"multiples-of-5",
|
||||
func(_ context.Context, _ pipe.Env, line []byte, w *bufio.Writer) error {
|
||||
n, err := strconv.Atoi(string(line))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n%5 != 0 {
|
||||
return nil
|
||||
}
|
||||
_, err = fmt.Fprintf(w, " %d", n)
|
||||
return err
|
||||
},
|
||||
),
|
||||
// Read the words and square them, emitting the results one per
|
||||
// line:
|
||||
pipe.ScannerFunction(
|
||||
"square-multiples-of-5",
|
||||
func(r io.Reader) (pipe.Scanner, error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Split(bufio.ScanWords)
|
||||
return scanner, nil
|
||||
},
|
||||
func(_ context.Context, _ pipe.Env, line []byte, w *bufio.Writer) error {
|
||||
n, err := strconv.Atoi(string(line))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(w, "%d\n", n*n)
|
||||
return err
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
out, err := p.Output(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "25\n100\n225\n400\n", out)
|
||||
}
|
||||
|
||||
func TestScannerAlwaysFlushes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
var length int64
|
||||
|
||||
p := pipe.New()
|
||||
// Print the numbers from 1 to 20 (generated from scratch):
|
||||
p.Add(
|
||||
pipe.IgnoreError(
|
||||
seqFunction(20),
|
||||
pipe.IsPipeError,
|
||||
),
|
||||
// Pass the numbers through up to 7, then exit with an
|
||||
// ignored error:
|
||||
pipe.IgnoreError(
|
||||
pipe.LinewiseFunction(
|
||||
"error-after-7",
|
||||
func(_ context.Context, _ pipe.Env, line []byte, w *bufio.Writer) error {
|
||||
fmt.Fprintf(w, "%s\n", line)
|
||||
if string(line) == "7" {
|
||||
return errors.New("ignore")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
),
|
||||
func(err error) bool {
|
||||
return err.Error() == "ignore"
|
||||
},
|
||||
),
|
||||
// Read the numbers and add them into the sum:
|
||||
pipe.Function(
|
||||
"compute-length",
|
||||
func(_ context.Context, _ pipe.Env, stdin io.Reader, _ io.Writer) error {
|
||||
var err error
|
||||
length, err = io.Copy(ioutil.Discard, stdin)
|
||||
return err
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
err := p.Run(ctx)
|
||||
assert.NoError(t, err)
|
||||
// Make sure that all of the bytes emitted before the second
|
||||
// stage's error were received by the third stage:
|
||||
assert.EqualValues(t, 14, length)
|
||||
}
|
||||
|
||||
func TestScannerFinishEarly(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
var length int64
|
||||
|
||||
p := pipe.New()
|
||||
// Print the numbers from 1 to 20 (generated from scratch):
|
||||
p.Add(
|
||||
pipe.IgnoreError(
|
||||
seqFunction(20),
|
||||
pipe.IsPipeError,
|
||||
),
|
||||
// Pass the numbers through up to 7, then exit with an
|
||||
// ignored error:
|
||||
pipe.LinewiseFunction(
|
||||
"finish-after-7",
|
||||
func(_ context.Context, _ pipe.Env, line []byte, w *bufio.Writer) error {
|
||||
fmt.Fprintf(w, "%s\n", line)
|
||||
if string(line) == "7" {
|
||||
return pipe.FinishEarly
|
||||
}
|
||||
return nil
|
||||
},
|
||||
),
|
||||
// Read the numbers and add them into the sum:
|
||||
pipe.Function(
|
||||
"compute-length",
|
||||
func(_ context.Context, _ pipe.Env, stdin io.Reader, _ io.Writer) error {
|
||||
var err error
|
||||
length, err = io.Copy(ioutil.Discard, stdin)
|
||||
return err
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
err := p.Run(ctx)
|
||||
assert.NoError(t, err)
|
||||
// Make sure that all of the bytes emitted before the second
|
||||
// stage's error were received by the third stage:
|
||||
assert.EqualValues(t, 14, length)
|
||||
}
|
||||
|
||||
func TestPrintln(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(pipe.Println("Look Ma, no hands!"))
|
||||
out, err := p.Output(ctx)
|
||||
if assert.NoError(t, err) {
|
||||
assert.EqualValues(t, "Look Ma, no hands!\n", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintf(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
p := pipe.New()
|
||||
p.Add(pipe.Printf("Strangely recursive: %T", p))
|
||||
out, err := p.Output(ctx)
|
||||
if assert.NoError(t, err) {
|
||||
assert.EqualValues(t, "Strangely recursive: *pipe.Pipeline", out)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSingleProgram(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Command("true"),
|
||||
)
|
||||
assert.NoError(b, p.Run(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTenPrograms(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Command("echo", "hello world"),
|
||||
pipe.Command("cat"),
|
||||
pipe.Command("cat"),
|
||||
pipe.Command("cat"),
|
||||
pipe.Command("cat"),
|
||||
pipe.Command("cat"),
|
||||
pipe.Command("cat"),
|
||||
pipe.Command("cat"),
|
||||
pipe.Command("cat"),
|
||||
pipe.Command("cat"),
|
||||
)
|
||||
out, err := p.Output(ctx)
|
||||
if assert.NoError(b, err) {
|
||||
assert.EqualValues(b, "hello world\n", out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTenFunctions(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Println("hello world"),
|
||||
pipe.Function("copy1", catFn),
|
||||
pipe.Function("copy2", catFn),
|
||||
pipe.Function("copy3", catFn),
|
||||
pipe.Function("copy4", catFn),
|
||||
pipe.Function("copy5", catFn),
|
||||
pipe.Function("copy6", catFn),
|
||||
pipe.Function("copy7", catFn),
|
||||
pipe.Function("copy8", catFn),
|
||||
pipe.Function("copy9", catFn),
|
||||
)
|
||||
out, err := p.Output(ctx)
|
||||
if assert.NoError(b, err) {
|
||||
assert.EqualValues(b, "hello world\n", out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTenMixedStages(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
p := pipe.New()
|
||||
p.Add(
|
||||
pipe.Command("echo", "hello world"),
|
||||
pipe.Function("copy1", catFn),
|
||||
pipe.Command("cat"),
|
||||
pipe.Function("copy2", catFn),
|
||||
pipe.Command("cat"),
|
||||
pipe.Function("copy3", catFn),
|
||||
pipe.Command("cat"),
|
||||
pipe.Function("copy4", catFn),
|
||||
pipe.Command("cat"),
|
||||
pipe.Function("copy5", catFn),
|
||||
)
|
||||
out, err := p.Output(ctx)
|
||||
if assert.NoError(b, err) {
|
||||
assert.EqualValues(b, "hello world\n", out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func catFn(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
|
||||
_, err := io.Copy(stdout, stdin)
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
func Print(a ...interface{}) Stage {
|
||||
return Function(
|
||||
"print",
|
||||
func(_ context.Context, _ Env, _ io.Reader, stdout io.Writer) error {
|
||||
_, err := fmt.Fprint(stdout, a...)
|
||||
return err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func Println(a ...interface{}) Stage {
|
||||
return Function(
|
||||
"println",
|
||||
func(_ context.Context, _ Env, _ io.Reader, stdout io.Writer) error {
|
||||
_, err := fmt.Fprintln(stdout, a...)
|
||||
return err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func Printf(format string, a ...interface{}) Stage {
|
||||
return Function(
|
||||
"printf",
|
||||
func(_ context.Context, _ Env, _ io.Reader, stdout io.Writer) error {
|
||||
_, err := fmt.Fprintf(stdout, format, a...)
|
||||
return err
|
||||
},
|
||||
)
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Scanner defines the interface (which is implemented by
|
||||
// `bufio.Scanner`) that is needed by `AddScannerFunction()`. See
|
||||
// `bufio.Scanner` for how these methods should behave.
|
||||
type Scanner interface {
|
||||
Scan() bool
|
||||
Bytes() []byte
|
||||
Err() error
|
||||
}
|
||||
|
||||
// FinishEarly is an error that can be returned by a
|
||||
// `LinewiseStageFunc` to request that the iteration be ended early,
|
||||
// without an error.
|
||||
//nolint:revive
|
||||
var FinishEarly = errors.New("finish stage early")
|
||||
|
||||
// NewScannerFunc is used to create a `Scanner` for scanning input
|
||||
// that is coming from `r`.
|
||||
type NewScannerFunc func(r io.Reader) (Scanner, error)
|
||||
|
||||
// ScannerFunction creates a function-based `Stage`. The function will
|
||||
// be passed input, one line at a time, and may emit output. See the
|
||||
// definition of `LinewiseStageFunc` for more information.
|
||||
func ScannerFunction(
|
||||
name string, newScanner NewScannerFunc, f LinewiseStageFunc,
|
||||
) Stage {
|
||||
stage := Function(
|
||||
name,
|
||||
func(ctx context.Context, env Env, stdin io.Reader, stdout io.Writer) (theErr error) {
|
||||
scanner, err := newScanner(stdin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var out *bufio.Writer
|
||||
if stdout != nil {
|
||||
out = bufio.NewWriter(stdout)
|
||||
defer func() {
|
||||
err := out.Flush()
|
||||
if err != nil && theErr == nil {
|
||||
// Note: this sets the named return value,
|
||||
// thereby causing the whole stage to report
|
||||
// the error.
|
||||
theErr = err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
err := f(ctx, env, scanner.Bytes(), out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
// `p.AddFunction()` arranges for `stdout` to be closed.
|
||||
},
|
||||
)
|
||||
return IgnoreError(stage, IsError(FinishEarly))
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package pipe
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Stage is an element of a `Pipeline`.
|
||||
type Stage interface {
|
||||
// Name returns the name of the stage.
|
||||
Name() string
|
||||
|
||||
// Start starts the stage in the background, in the environment
|
||||
// described by `env`, and using `stdin` as input. (`stdin` should
|
||||
// be set to `nil` if the stage is to receive no input, which
|
||||
// might be the case for the first stage in a pipeline.) It
|
||||
// returns an `io.ReadCloser` from which the stage's output can be
|
||||
// read (or `nil` if it generates no output, which should only be
|
||||
// the case for the last stage in a pipeline). It is the stages'
|
||||
// responsibility to close `stdin` (if it is not nil) when it has
|
||||
// read all of the input that it needs, and to close the write end
|
||||
// of its output reader when it is done, as that is generally how
|
||||
// the subsequent stage knows that it has received all of its
|
||||
// input and can finish its work, too.
|
||||
//
|
||||
// If `Start()` returns without an error, `Wait()` must also be
|
||||
// called, to allow all resources to be freed.
|
||||
Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error)
|
||||
|
||||
// Wait waits for the stage to be done, either because it has
|
||||
// finished or because it has been killed due to the expiration of
|
||||
// the context passed to `Start()`.
|
||||
Wait() error
|
||||
}
|
Загрузка…
Ссылка в новой задаче