internal/pipe: new package for handling command pipelines

This commit is contained in:
Michael Haggerty 2021-11-02 09:30:01 +01:00
Родитель 6aa6890117
Коммит 1458ae5f8b
12 изменённых файлов: 1623 добавлений и 10 удалений

5
go.mod
Просмотреть файл

@ -6,6 +6,7 @@ require (
github.com/cli/safeexec v1.0.0
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.4.0
gopkg.in/yaml.v2 v2.2.7 // indirect
github.com/stretchr/testify v1.7.0
go.uber.org/goleak v1.1.12
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
)

50
go.sum
Просмотреть файл

@ -1,19 +1,53 @@
github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI=
github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA=
go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

223
internal/pipe/command.go Normal file
Просмотреть файл

@ -0,0 +1,223 @@
package pipe
import (
"bytes"
"context"
"errors"
"io"
"os"
"os/exec"
"sync/atomic"
"syscall"
"time"
"golang.org/x/sync/errgroup"
)
// commandStage is a pipeline `Stage` based on running an external
// command and piping the data through its stdin and stdout.
type commandStage struct {
name string
stdin io.Closer
cmd *exec.Cmd
done chan struct{}
wg errgroup.Group
stderr bytes.Buffer
// If the context expired and we attempted to kill the command,
// `ctx.Err()` is stored here.
ctxErr atomic.Value
}
// Command returns a pipeline `Stage` based on the specified external
// `command`, run with the given command-line `args`. Its stdin and
// stdout are handled as usual, and its stderr is collected and
// included in any `*exec.ExitError` that the command might emit.
func Command(command string, args ...string) Stage {
if len(command) == 0 {
panic("attempt to create command with empty command")
}
cmd := exec.Command(command, args...)
return CommandStage(command, cmd)
}
// Command returns a pipeline `Stage` with the name `name`, based on
// the specified `cmd`. Its stdin and stdout are handled as usual, and
// its stderr is collected and included in any `*exec.ExitError` that
// the command might emit.
func CommandStage(name string, cmd *exec.Cmd) Stage {
return &commandStage{
name: name,
cmd: cmd,
done: make(chan struct{}),
}
}
func (s *commandStage) Name() string {
return s.name
}
func (s *commandStage) Start(
ctx context.Context, env Env, stdin io.ReadCloser,
) (io.ReadCloser, error) {
if s.cmd.Dir == "" {
s.cmd.Dir = env.Dir
}
if stdin != nil {
s.cmd.Stdin = stdin
// Also keep a copy so that we can close it when the command exits:
s.stdin = stdin
}
stdout, err := s.cmd.StdoutPipe()
if err != nil {
return nil, err
}
// If the caller hasn't arranged otherwise, read the command's
// standard error into our `stderr` field:
if s.cmd.Stderr == nil {
// We can't just set `s.cmd.Stderr = &s.stderr`, because if we
// do then `s.cmd.Wait()` doesn't wait to be sure that all
// error output has been captured. By doing this ourselves, we
// can be sure.
p, err := s.cmd.StderrPipe()
if err != nil {
return nil, err
}
s.wg.Go(func() error {
_, err := io.Copy(&s.stderr, p)
// We don't consider `ErrClosed` an error (FIXME: is this
// correct?):
if err != nil && !errors.Is(err, os.ErrClosed) {
return err
}
return nil
})
}
// Put the command in its own process group:
if s.cmd.SysProcAttr == nil {
s.cmd.SysProcAttr = &syscall.SysProcAttr{}
}
s.cmd.SysProcAttr.Setpgid = true
if err := s.cmd.Start(); err != nil {
return nil, err
}
// Arrange for the process to be killed (gently) if the context
// expires before the command exits normally:
go func() {
select {
case <-ctx.Done():
s.kill(ctx.Err())
case <-s.done:
// Process already done; no need to kill anything.
}
}()
return stdout, nil
}
// kill is called to kill the process if the context expires. `err` is
// the corresponding value of `Context.Err()`.
func (s *commandStage) kill(err error) {
// I believe that the calls to `syscall.Kill()` in this method are
// racy. It could be that s.cmd.Wait() succeeds immediately before
// this call, in which case the process group wouldn't exist
// anymore. But I don't see any way to avoid this without
// duplicating a lot of code from `exec.Cmd`. (`os.Cmd.Kill()` and
// `os.Cmd.Signal()` appear to be race-free, but only because they
// use internal synchronization. But those methods only kill the
// process, not the process group, so they are not suitable here.
// We started the process with PGID == PID:
pid := s.cmd.Process.Pid
select {
case <-s.done:
// Process has ended; no need to kill it again.
return
default:
}
// Record the `ctx.Err()`, which will be used as the error result
// for this stage.
s.ctxErr.Store(err)
// First try to kill using a relatively gentle signal so that
// the processes have a chance to clean up after themselves:
_ = syscall.Kill(-pid, syscall.SIGTERM)
// Well-behaved processes should commit suicide after the above,
// but if they don't exit within 2s, murder the whole lot of them:
go func() {
// Use an explicit `time.Timer` rather than `time.After()` so
// that we can stop it (freeing resources) promptly if the
// command exits before the timer triggers.
timer := time.NewTimer(2 * time.Second)
defer timer.Stop()
select {
case <-s.done:
// Process has ended; no need to kill it again.
case <-timer.C:
_ = syscall.Kill(-pid, syscall.SIGKILL)
}
}()
}
// filterCmdError interprets `err`, which was returned by `Cmd.Wait()`
// (possibly `nil`), possibly modifying it or ignoring it. It returns
// the error that should actually be returned to the caller (possibly
// `nil`).
func (s *commandStage) filterCmdError(err error) error {
if err == nil {
return nil
}
eErr, ok := err.(*exec.ExitError)
if !ok {
return err
}
ctxErr, ok := s.ctxErr.Load().(error)
if ok {
// If the process looks like it was killed by us, substitute
// `ctxErr` for the process's own exit error.
ps, ok := eErr.ProcessState.Sys().(syscall.WaitStatus)
if ok && ps.Signaled() &&
(ps.Signal() == syscall.SIGTERM || ps.Signal() == syscall.SIGKILL) {
return ctxErr
}
}
eErr.Stderr = s.stderr.Bytes()
return eErr
}
func (s *commandStage) Wait() error {
defer close(s.done)
// Make sure that any stderr is copied before `s.cmd.Wait()`
// closes the read end of the pipe:
wErr := s.wg.Wait()
err := s.cmd.Wait()
err = s.filterCmdError(err)
if err == nil && wErr != nil {
err = wErr
}
if s.stdin != nil {
cErr := s.stdin.Close()
if cErr != nil && err == nil {
return cErr
}
}
return err
}

Просмотреть файл

@ -0,0 +1,132 @@
package pipe
import (
"errors"
"io"
"os/exec"
"syscall"
)
// ErrorFilter is a function that can filter errors from
// `Stage.Wait()`. The original error (possibly nil) is passed in as
// an argument, and whatever the function returns is the error
// (possibly nil) that is actually emitted.
type ErrorFilter func(err error) error
func FilterError(s Stage, filter ErrorFilter) Stage {
return efStage{Stage: s, filter: filter}
}
type efStage struct {
Stage
filter ErrorFilter
}
func (s efStage) Wait() error {
return s.filter(s.Stage.Wait())
}
// ErrorMatcher decides whether its argument matches some class of
// errors (e.g., errors that we want to ignore). The function will
// only be invoked for non-nil errors.
type ErrorMatcher func(err error) bool
// IgnoreError creates a stage that acts like `s` except that it
// ignores any errors that are matched by `em`. Use like
//
// p.Add(pipe.IgnoreError(
// someStage,
// func(err error) bool {
// var myError *MyErrorType
// return errors.As(err, &myError) && myError.foo == 42
// },
// )
//
// The second argument can also be one of the `ErrorMatcher`s that are
// provided by this package (e.g., `IsError(target)`,
// IsSignal(signal), `IsSIGPIPE`, `IsEPIPE`, `IsPipeError`), or one of
// the functions from the standard library that has the same signature
// (e.g., `os.IsTimeout`), or some combination of these (e.g.,
// `AnyError(IsSIGPIPE, os.IsTimeout)`).
func IgnoreError(s Stage, em ErrorMatcher) Stage {
return FilterError(s,
func(err error) error {
if err == nil || em(err) {
return nil
}
return err
},
)
}
// AnyError returns an `ErrorMatcher` that returns true for an error
// that matches any of the `ems`.
func AnyError(ems ...ErrorMatcher) ErrorMatcher {
return func(err error) bool {
if err == nil {
return false
}
for _, em := range ems {
if em(err) {
return true
}
}
return false
}
}
// IsError returns an ErrorIdentifier for the specified target error,
// matched using `errors.Is()`. Use like
//
// p.Add(pipe.IgnoreError(someStage, IsError(io.EOF)))
func IsError(target error) ErrorMatcher {
return func(err error) bool {
return errors.Is(err, target)
}
}
// IsSIGPIPE returns an `ErrorMatcher` that matches `*exec.ExitError`s
// that were caused by the specified signal. The match for
// `*exec.ExitError`s uses `errors.As()`.
func IsSignal(signal syscall.Signal) ErrorMatcher {
return func(err error) bool {
var eErr *exec.ExitError
if !errors.As(err, &eErr) {
return false
}
status, ok := eErr.Sys().(syscall.WaitStatus)
return ok && status.Signaled() && status.Signal() == signal
}
}
var (
// IsSIGPIPE is an `ErrorMatcher` that matches `*exec.ExitError`s
// that were caused by SIGPIPE. The match for `*exec.ExitError`s
// uses `errors.As()`. Use like
//
// p.Add(IgnoreError(someStage, IsSIGPIPE))
IsSIGPIPE = IsSignal(syscall.SIGPIPE)
// IsEPIPE is an `ErrorMatcher` that matches `syscall.EPIPE` using
// `errors.Is()`. Use like
//
// p.Add(IgnoreError(someStage, IsEPIPE))
IsEPIPE = IsError(syscall.EPIPE)
// IsErrClosedPipe is an `ErrorMatcher` that matches
// `io.ErrClosedPipe` using `errors.Is()`. (`io.ErrClosedPipe` is
// the error that results from writing to a closed
// `*io.PipeWriter`.) Use like
//
// p.Add(IgnoreError(someStage, IsErrClosedPipe))
IsErrClosedPipe = IsError(io.ErrClosedPipe)
// IsPipeError is an `ErrorMatcher` that matches a few different
// errors that typically result if a stage writes to a subsequent
// stage that has stopped reading from its stdin. Use like
//
// p.Add(IgnoreError(someStage, IsPipeError))
IsPipeError = AnyError(IsSIGPIPE, IsEPIPE, IsErrClosedPipe)
)

66
internal/pipe/function.go Normal file
Просмотреть файл

@ -0,0 +1,66 @@
package pipe
import (
"context"
"fmt"
"io"
)
// StageFunc is a function that can be used to power a `goStage`. It
// should read its input from `stdin` and write its output to
// `stdout`. `stdin` and `stdout` will be closed automatically (if
// necessary) once the function returns.
//
// Neither `stdin` nor `stdout` are necessarily buffered. If the
// `StageFunc` requires buffering, it needs to arrange that itself.
//
// A `StageFunc` is run in a separate goroutine, so it must be careful
// to synchronize any data access aside from reading and writing.
type StageFunc func(ctx context.Context, env Env, stdin io.Reader, stdout io.Writer) error
// Function returns a pipeline `Stage` that will run a `StageFunc` in
// a separate goroutine to process the data. See `StageFunc` for more
// information.
func Function(name string, f StageFunc) Stage {
return &goStage{
name: name,
f: f,
done: make(chan struct{}),
}
}
// goStage is a `Stage` that does its work by running an arbitrary
// `stageFunc` in a goroutine.
type goStage struct {
name string
f StageFunc
done chan struct{}
err error
}
func (s *goStage) Name() string {
return s.name
}
func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) {
r, w := io.Pipe()
go func() {
s.err = s.f(ctx, env, stdin, w)
if err := w.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err)
}
if stdin != nil {
if err := stdin.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
}
}
close(s.done)
}()
return r, nil
}
func (s *goStage) Wait() error {
<-s.done
return s.err
}

62
internal/pipe/iocopier.go Normal file
Просмотреть файл

@ -0,0 +1,62 @@
package pipe
import (
"context"
"errors"
"io"
"os"
)
// ioCopier is a stage that copies its stdin to a specified
// `io.Writer`. It generates no stdout itself.
type ioCopier struct {
w io.WriteCloser
done chan struct{}
err error
}
func newIOCopier(w io.WriteCloser) *ioCopier {
return &ioCopier{
w: w,
done: make(chan struct{}),
}
}
func (s *ioCopier) Name() string {
return "ioCopier"
}
// This method always returns `nil, nil`.
func (s *ioCopier) Start(ctx context.Context, _ Env, r io.ReadCloser) (io.ReadCloser, error) {
go func() {
_, err := io.Copy(s.w, r)
// We don't consider `ErrClosed` an error (FIXME: is this
// correct?):
if err != nil && !errors.Is(err, os.ErrClosed) {
s.err = err
}
if err := r.Close(); err != nil && s.err == nil {
s.err = err
}
if err := s.w.Close(); err != nil && s.err == nil {
s.err = err
}
close(s.done)
}()
// FIXME: if `s.w.Write()` is blocking (e.g., because there is a
// downstream process that is not reading from the other side),
// there's no way to terminate the copy when the context expires.
// This is not too bad, because the `io.Copy()` call will exit by
// itself when its input is closed.
//
// We could, however, be smarter about exiting more quickly if the
// context expires but `s.w.Write()` is not blocking.
return nil, nil
}
func (s *ioCopier) Wait() error {
<-s.done
return s.err
}

74
internal/pipe/linewise.go Normal file
Просмотреть файл

@ -0,0 +1,74 @@
package pipe
import (
"bufio"
"bytes"
"context"
"io"
)
// LinewiseStageFunc is a function that can be embedded in a
// `goStage`. It is called once per line in the input (where "line"
// can be defined via any `bufio.Scanner`). It should process the line
// and may write whatever it likes to `stdout`, which is a buffered
// writer whose contents are forwarded to the input of the next stage
// of the pipeline. The function needn't write one line of output per
// line of input.
//
// The function mustn't retain copies of `line`, since it may be
// overwritten every time the function is called.
//
// The function needn't flush or close `stdout` (this will be done
// automatically when all of the input has been processed).
//
// If there is an error parsing the input into lines, or if this
// function returns an error, then the whole pipeline will be aborted
// with that error. However, if the function returns the special error
// `pipe.FinishEarly`, the stage will stop processing immediately with
// a `nil` error value.
//
// The function will be called in a separate goroutine, so it must be
// careful to synchronize any data access aside from writing to
// `stdout`.
type LinewiseStageFunc func(
ctx context.Context, env Env, line []byte, stdout *bufio.Writer,
) error
// LinewiseFunction returns a function-based `Stage`. The input will
// be split into LF-terminated lines and passed to the function one
// line at a time (without the LF). The function may emit output to
// its `stdout` argument. See the definition of `LinewiseStageFunc`
// for more information.
//
// Note that the stage will emit an error if any line (including its
// end-of-line terminator) exceeds 64 kiB in length. If this is too
// short, use `ScannerFunction()` directly with your own
// `NewScannerFunc` as argument, or use `Function()` directly with
// your own `StageFunc`.
func LinewiseFunction(name string, f LinewiseStageFunc) Stage {
return ScannerFunction(
name,
func(r io.Reader) (Scanner, error) {
scanner := bufio.NewScanner(r)
// Split based on strict LF (we don't accept CRLF):
scanner.Split(ScanLFTerminatedLines)
return scanner, nil
},
f,
)
}
// ScanLFTerminatedLines is a `bufio.SplitFunc` that splits its input
// into lines at LF characters (not treating CR specially).
func ScanLFTerminatedLines(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := bytes.IndexByte(data, '\n'); i != -1 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
}

211
internal/pipe/pipeline.go Normal file
Просмотреть файл

@ -0,0 +1,211 @@
package pipe
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"sync/atomic"
)
// Env represents the environment that a pipeline stage should run in.
// It is passed to `Stage.Start()`.
type Env struct {
// The directory in which external commands should be executed by
// default.
Dir string
}
// Pipeline represents a Unix-like pipe that can include multiple
// stages, including external processes but also and stages written in
// Go.
type Pipeline struct {
env Env
stdin io.Reader
stdout io.WriteCloser
stages []Stage
cancel func()
// Atomically written and read value, nonzero if the pipeline has
// been started. This is only used for lifecycle sanity checks but
// does not guarantee that clients are using the class correctly.
started uint32
}
type nopWriteCloser struct {
io.Writer
}
func (w nopWriteCloser) Close() error {
return nil
}
// NewPipeline returns a Pipeline struct with all of the `options`
// applied.
func New(options ...Option) *Pipeline {
p := &Pipeline{}
for _, option := range options {
option(p)
}
return p
}
// Option is a type alias for Pipeline functional options.
type Option func(*Pipeline)
// WithDir sets the default directory for running external commands.
func WithDir(dir string) Option {
return func(p *Pipeline) {
p.env.Dir = dir
}
}
// WithStdin assigns stdin to the first command in the pipeline.
func WithStdin(stdin io.Reader) Option {
return func(p *Pipeline) {
p.stdin = stdin
}
}
// WithStdout assigns stdout to the last command in the pipeline.
func WithStdout(stdout io.Writer) Option {
return func(p *Pipeline) {
p.stdout = nopWriteCloser{stdout}
}
}
// WithStdoutCloser assigns stdout to the last command in the
// pipeline, and closes stdout when it's done.
func WithStdoutCloser(stdout io.WriteCloser) Option {
return func(p *Pipeline) {
p.stdout = stdout
}
}
func (p *Pipeline) hasStarted() bool {
return atomic.LoadUint32(&p.started) != 0
}
// Add appends one or more stages to the pipeline.
func (p *Pipeline) Add(stages ...Stage) {
if p.hasStarted() {
panic("attempt to modify a pipeline that has already started")
}
p.stages = append(p.stages, stages...)
}
// AddWithIgnoredError appends one or more stages that are ignoring
// the passed in error to the pipeline.
func (p *Pipeline) AddWithIgnoredError(em ErrorMatcher, stages ...Stage) {
if p.hasStarted() {
panic("attempt to modify a pipeline that has already started")
}
for _, stage := range stages {
p.stages = append(p.stages, IgnoreError(stage, em))
}
}
// Start starts the commands in the pipeline. If `Start()` exits
// without an error, `Wait()` must also be called, to allow all
// resources to be freed.
func (p *Pipeline) Start(ctx context.Context) error {
if p.hasStarted() {
panic("attempt to start a pipeline that has already started")
}
atomic.StoreUint32(&p.started, 1)
ctx, p.cancel = context.WithCancel(ctx)
var nextStdin io.ReadCloser
if p.stdin != nil {
// We don't want the first stage to actually close this, and
// it's not even an `io.ReadCloser`, so fake it:
nextStdin = ioutil.NopCloser(p.stdin)
}
for i, s := range p.stages {
var err error
stdout, err := s.Start(ctx, p.env, nextStdin)
if err != nil {
// Close the pipe that the previous stage was writing to.
// That should cause it to exit even if it's not minding
// its context.
if nextStdin != nil {
_ = nextStdin.Close()
}
// Kill and wait for any stages that have been started
// already to finish:
p.cancel()
for _, s := range p.stages[:i] {
_ = s.Wait()
}
return fmt.Errorf("starting pipeline stage %q: %w", s.Name(), err)
}
nextStdin = stdout
}
// If the pipeline was configured with a `stdout`, add a synthetic
// stage to copy the last stage's stdout to that writer:
if p.stdout != nil {
c := newIOCopier(p.stdout)
p.stages = append(p.stages, c)
// `ioCopier.Start()` never fails:
_, _ = c.Start(ctx, p.env, nextStdin)
}
return nil
}
func (p *Pipeline) Output(ctx context.Context) ([]byte, error) {
var buf bytes.Buffer
p.stdout = nopWriteCloser{&buf}
err := p.Run(ctx)
return buf.Bytes(), err
}
// Wait waits for each stage in the pipeline to exit.
func (p *Pipeline) Wait() error {
if !p.hasStarted() {
panic("unable to wait on a pipeline that has not started")
}
// Make sure that all of the cleanup eventually happens:
defer p.cancel()
var earliestStageErr error
var earliestFailedStage Stage
for i := len(p.stages) - 1; i >= 0; i-- {
s := p.stages[i]
err := s.Wait()
if err != nil {
// Overwrite any existing values here so that we end up
// retaining the last error that we see; i.e., the error
// that happened earliest in the pipeline.
earliestStageErr = err
earliestFailedStage = s
}
}
if earliestStageErr != nil {
return fmt.Errorf("%s: %w", earliestFailedStage.Name(), earliestStageErr)
}
return nil
}
// Run starts and waits for the commands in the pipeline.
func (p *Pipeline) Run(ctx context.Context) error {
if err := p.Start(ctx); err != nil {
return err
}
return p.Wait()
}

Просмотреть файл

@ -0,0 +1,664 @@
package pipe_test
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"github.com/github/git-sizer/internal/pipe"
)
func TestMain(m *testing.M) {
// Check whether this package's test suite leaks any goroutines:
goleak.VerifyTestMain(m)
}
func TestPipelineFirstStageFailsToStart(t *testing.T) {
t.Parallel()
ctx := context.Background()
startErr := errors.New("foo")
p := pipe.New()
p.Add(
ErrorStartingStage{startErr},
ErrorStartingStage{errors.New("this error should never happen")},
)
assert.ErrorIs(t, p.Run(ctx), startErr)
}
func TestPipelineSecondStageFailsToStart(t *testing.T) {
t.Parallel()
ctx := context.Background()
startErr := errors.New("foo")
p := pipe.New()
p.Add(
seqFunction(20000),
ErrorStartingStage{startErr},
)
assert.ErrorIs(t, p.Run(ctx), startErr)
}
func TestPipelineSingleCommandOutput(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := pipe.New()
p.Add(pipe.Command("echo", "hello world"))
out, err := p.Output(ctx)
if assert.NoError(t, err) {
assert.EqualValues(t, "hello world\n", out)
}
}
func TestPipelineSingleCommandWithStdout(t *testing.T) {
t.Parallel()
ctx := context.Background()
stdout := &bytes.Buffer{}
p := pipe.New(pipe.WithStdout(stdout))
p.Add(pipe.Command("echo", "hello world"))
if assert.NoError(t, p.Run(ctx)) {
assert.Equal(t, "hello world\n", stdout.String())
}
}
func TestNontrivialPipeline(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := pipe.New()
p.Add(
pipe.Command("echo", "hello world"),
pipe.Command("sed", "s/hello/goodbye/"),
)
out, err := p.Output(ctx)
if assert.NoError(t, err) {
assert.EqualValues(t, "goodbye world\n", out)
}
}
func TestPipelineReadFromSlowly(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
r, w := io.Pipe()
var buf []byte
readErr := make(chan error, 1)
go func() {
time.Sleep(200 * time.Millisecond)
var err error
buf, err = ioutil.ReadAll(r)
readErr <- err
}()
p := pipe.New(pipe.WithStdout(w))
p.Add(pipe.Command("echo", "hello world"))
assert.NoError(t, p.Run(ctx))
time.Sleep(100 * time.Millisecond)
// It's not super-intuitive, but `w` has to be closed here so that
// the `ioutil.ReadAll()` call above knows that it's done:
_ = w.Close()
assert.NoError(t, <-readErr)
assert.Equal(t, "hello world\n", string(buf))
}
func TestPipelineReadFromSlowly2(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
r, w := io.Pipe()
var buf []byte
readErr := make(chan error, 1)
go func() {
time.Sleep(100 * time.Millisecond)
for {
var c [1]byte
_, err := r.Read(c[:])
if err != nil {
if err == io.EOF {
readErr <- nil
return
}
readErr <- err
return
}
buf = append(buf, c[0])
time.Sleep(1 * time.Millisecond)
}
}()
p := pipe.New(pipe.WithStdout(w))
p.Add(pipe.Command("seq", "100"))
assert.NoError(t, p.Run(ctx))
time.Sleep(200 * time.Millisecond)
// It's not super-intuitive, but `w` has to be closed here so that
// the `ioutil.ReadAll()` call above knows that it's done:
_ = w.Close()
assert.NoError(t, <-readErr)
assert.Equal(t, 292, len(buf))
}
func TestPipelineTwoCommandsPiping(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := pipe.New()
p.Add(pipe.Command("echo", "hello world"))
assert.Panics(t, func() { p.Add(pipe.Command("")) })
out, err := p.Output(ctx)
if assert.NoError(t, err) {
assert.EqualValues(t, "hello world\n", out)
}
}
func TestPipelineDir(t *testing.T) {
t.Parallel()
ctx := context.Background()
wdir, err := os.Getwd()
require.NoError(t, err)
dir, err := ioutil.TempDir(wdir, "pipeline-test-")
require.NoError(t, err)
defer os.RemoveAll(dir)
p := pipe.New(pipe.WithDir(dir))
p.Add(pipe.Command("pwd"))
out, err := p.Output(ctx)
if assert.NoError(t, err) {
assert.Equal(t, dir, strings.TrimSuffix(string(out), "\n"))
}
}
func TestPipelineExit(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := pipe.New()
p.Add(
pipe.Command("false"),
pipe.Command("true"),
)
assert.EqualError(t, p.Run(ctx), "false: exit status 1")
}
func TestPipelineStderr(t *testing.T) {
t.Parallel()
ctx := context.Background()
dir, err := ioutil.TempDir("", "pipeline-test-")
require.NoError(t, err)
defer os.RemoveAll(dir)
p := pipe.New(pipe.WithDir(dir))
p.Add(pipe.Command("ls", "doesnotexist"))
_, err = p.Output(ctx)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "ls: exit status")
}
}
func TestPipelineInterrupted(t *testing.T) {
t.Parallel()
stdout := &bytes.Buffer{}
p := pipe.New(pipe.WithStdout(stdout))
p.Add(pipe.Command("sleep", "10"))
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := p.Start(ctx)
require.NoError(t, err)
err = p.Wait()
assert.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestPipelineCanceled(t *testing.T) {
t.Parallel()
stdout := &bytes.Buffer{}
p := pipe.New(pipe.WithStdout(stdout))
p.Add(pipe.Command("sleep", "10"))
ctx, cancel := context.WithCancel(context.Background())
err := p.Start(ctx)
require.NoError(t, err)
cancel()
err = p.Wait()
assert.ErrorIs(t, err, context.Canceled)
}
// Verify the correct error if a command in the pipeline exits before
// reading all of its predecessor's output. Note that the amount of
// unread output in this case *does fit* within the OS-level pipe
// buffer.
func TestLittleEPIPE(t *testing.T) {
t.Parallel()
p := pipe.New()
p.Add(
pipe.Command("sh", "-c", "sleep 1; echo foo"),
pipe.Command("true"),
)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err := p.Run(ctx)
assert.EqualError(t, err, "sh: signal: broken pipe")
}
// Verify the correct error if one command in the pipeline exits
// before reading all of its predecessor's output. Note that the
// amount of unread output in this case *does not fit* within the
// OS-level pipe buffer.
func TestBigEPIPE(t *testing.T) {
t.Parallel()
p := pipe.New()
p.Add(
pipe.Command("seq", "100000"),
pipe.Command("true"),
)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err := p.Run(ctx)
assert.EqualError(t, err, "seq: signal: broken pipe")
}
// Verify the correct error if one command in the pipeline exits
// before reading all of its predecessor's output. Note that the
// amount of unread output in this case *does not fit* within the
// OS-level pipe buffer.
func TestIgnoredSIGPIPE(t *testing.T) {
t.Parallel()
p := pipe.New()
p.Add(
pipe.IgnoreError(pipe.Command("seq", "100000"), pipe.IsSIGPIPE),
pipe.Command("echo", "foo"),
)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
out, err := p.Output(ctx)
assert.NoError(t, err)
assert.EqualValues(t, "foo\n", out)
}
func TestFunction(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := pipe.New()
p.Add(
pipe.Print("hello world"),
pipe.Function(
"farewell",
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
buf, err := ioutil.ReadAll(stdin)
if err != nil {
return err
}
if string(buf) != "hello world" {
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
}
_, err = stdout.Write([]byte("goodbye, cruel world"))
return err
},
),
)
out, err := p.Output(ctx)
assert.NoError(t, err)
assert.EqualValues(t, "goodbye, cruel world", out)
}
func TestPipelineWithFunction(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := pipe.New()
p.Add(
pipe.Command("echo", "-n", "hello world"),
pipe.Function(
"farewell",
func(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
buf, err := ioutil.ReadAll(stdin)
if err != nil {
return err
}
if string(buf) != "hello world" {
return fmt.Errorf("expected \"hello world\"; got %q", string(buf))
}
_, err = stdout.Write([]byte("goodbye, cruel world"))
return err
},
),
pipe.Command("tr", "a-z", "A-Z"),
)
out, err := p.Output(ctx)
assert.NoError(t, err)
assert.EqualValues(t, "GOODBYE, CRUEL WORLD", out)
}
type ErrorStartingStage struct {
err error
}
func (s ErrorStartingStage) Name() string {
return "errorStartingStage"
}
func (s ErrorStartingStage) Start(
ctx context.Context, env pipe.Env, stdin io.ReadCloser,
) (io.ReadCloser, error) {
return ioutil.NopCloser(&bytes.Buffer{}), s.err
}
func (s ErrorStartingStage) Wait() error {
return nil
}
func seqFunction(n int) pipe.Stage {
return pipe.Function(
"seq",
func(_ context.Context, _ pipe.Env, _ io.Reader, stdout io.Writer) error {
for i := 1; i <= n; i++ {
_, err := fmt.Fprintf(stdout, "%d\n", i)
if err != nil {
return err
}
}
return nil
},
)
}
func TestPipelineWithLinewiseFunction(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := pipe.New()
// Print the numbers from 1 to 20 (generated from scratch):
p.Add(
seqFunction(20),
// Discard all but the multiples of 5, and emit the results
// separated by spaces on one line:
pipe.LinewiseFunction(
"multiples-of-5",
func(_ context.Context, _ pipe.Env, line []byte, w *bufio.Writer) error {
n, err := strconv.Atoi(string(line))
if err != nil {
return err
}
if n%5 != 0 {
return nil
}
_, err = fmt.Fprintf(w, " %d", n)
return err
},
),
// Read the words and square them, emitting the results one per
// line:
pipe.ScannerFunction(
"square-multiples-of-5",
func(r io.Reader) (pipe.Scanner, error) {
scanner := bufio.NewScanner(r)
scanner.Split(bufio.ScanWords)
return scanner, nil
},
func(_ context.Context, _ pipe.Env, line []byte, w *bufio.Writer) error {
n, err := strconv.Atoi(string(line))
if err != nil {
return err
}
_, err = fmt.Fprintf(w, "%d\n", n*n)
return err
},
),
)
out, err := p.Output(ctx)
assert.NoError(t, err)
assert.EqualValues(t, "25\n100\n225\n400\n", out)
}
func TestScannerAlwaysFlushes(t *testing.T) {
t.Parallel()
ctx := context.Background()
var length int64
p := pipe.New()
// Print the numbers from 1 to 20 (generated from scratch):
p.Add(
pipe.IgnoreError(
seqFunction(20),
pipe.IsPipeError,
),
// Pass the numbers through up to 7, then exit with an
// ignored error:
pipe.IgnoreError(
pipe.LinewiseFunction(
"error-after-7",
func(_ context.Context, _ pipe.Env, line []byte, w *bufio.Writer) error {
fmt.Fprintf(w, "%s\n", line)
if string(line) == "7" {
return errors.New("ignore")
}
return nil
},
),
func(err error) bool {
return err.Error() == "ignore"
},
),
// Read the numbers and add them into the sum:
pipe.Function(
"compute-length",
func(_ context.Context, _ pipe.Env, stdin io.Reader, _ io.Writer) error {
var err error
length, err = io.Copy(ioutil.Discard, stdin)
return err
},
),
)
err := p.Run(ctx)
assert.NoError(t, err)
// Make sure that all of the bytes emitted before the second
// stage's error were received by the third stage:
assert.EqualValues(t, 14, length)
}
func TestScannerFinishEarly(t *testing.T) {
t.Parallel()
ctx := context.Background()
var length int64
p := pipe.New()
// Print the numbers from 1 to 20 (generated from scratch):
p.Add(
pipe.IgnoreError(
seqFunction(20),
pipe.IsPipeError,
),
// Pass the numbers through up to 7, then exit with an
// ignored error:
pipe.LinewiseFunction(
"finish-after-7",
func(_ context.Context, _ pipe.Env, line []byte, w *bufio.Writer) error {
fmt.Fprintf(w, "%s\n", line)
if string(line) == "7" {
return pipe.FinishEarly
}
return nil
},
),
// Read the numbers and add them into the sum:
pipe.Function(
"compute-length",
func(_ context.Context, _ pipe.Env, stdin io.Reader, _ io.Writer) error {
var err error
length, err = io.Copy(ioutil.Discard, stdin)
return err
},
),
)
err := p.Run(ctx)
assert.NoError(t, err)
// Make sure that all of the bytes emitted before the second
// stage's error were received by the third stage:
assert.EqualValues(t, 14, length)
}
func TestPrintln(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := pipe.New()
p.Add(pipe.Println("Look Ma, no hands!"))
out, err := p.Output(ctx)
if assert.NoError(t, err) {
assert.EqualValues(t, "Look Ma, no hands!\n", out)
}
}
func TestPrintf(t *testing.T) {
t.Parallel()
ctx := context.Background()
p := pipe.New()
p.Add(pipe.Printf("Strangely recursive: %T", p))
out, err := p.Output(ctx)
if assert.NoError(t, err) {
assert.EqualValues(t, "Strangely recursive: *pipe.Pipeline", out)
}
}
func BenchmarkSingleProgram(b *testing.B) {
ctx := context.Background()
for i := 0; i < b.N; i++ {
p := pipe.New()
p.Add(
pipe.Command("true"),
)
assert.NoError(b, p.Run(ctx))
}
}
func BenchmarkTenPrograms(b *testing.B) {
ctx := context.Background()
for i := 0; i < b.N; i++ {
p := pipe.New()
p.Add(
pipe.Command("echo", "hello world"),
pipe.Command("cat"),
pipe.Command("cat"),
pipe.Command("cat"),
pipe.Command("cat"),
pipe.Command("cat"),
pipe.Command("cat"),
pipe.Command("cat"),
pipe.Command("cat"),
pipe.Command("cat"),
)
out, err := p.Output(ctx)
if assert.NoError(b, err) {
assert.EqualValues(b, "hello world\n", out)
}
}
}
func BenchmarkTenFunctions(b *testing.B) {
ctx := context.Background()
for i := 0; i < b.N; i++ {
p := pipe.New()
p.Add(
pipe.Println("hello world"),
pipe.Function("copy1", catFn),
pipe.Function("copy2", catFn),
pipe.Function("copy3", catFn),
pipe.Function("copy4", catFn),
pipe.Function("copy5", catFn),
pipe.Function("copy6", catFn),
pipe.Function("copy7", catFn),
pipe.Function("copy8", catFn),
pipe.Function("copy9", catFn),
)
out, err := p.Output(ctx)
if assert.NoError(b, err) {
assert.EqualValues(b, "hello world\n", out)
}
}
}
func BenchmarkTenMixedStages(b *testing.B) {
ctx := context.Background()
for i := 0; i < b.N; i++ {
p := pipe.New()
p.Add(
pipe.Command("echo", "hello world"),
pipe.Function("copy1", catFn),
pipe.Command("cat"),
pipe.Function("copy2", catFn),
pipe.Command("cat"),
pipe.Function("copy3", catFn),
pipe.Command("cat"),
pipe.Function("copy4", catFn),
pipe.Command("cat"),
pipe.Function("copy5", catFn),
)
out, err := p.Output(ctx)
if assert.NoError(b, err) {
assert.EqualValues(b, "hello world\n", out)
}
}
}
func catFn(_ context.Context, _ pipe.Env, stdin io.Reader, stdout io.Writer) error {
_, err := io.Copy(stdout, stdin)
return err
}

37
internal/pipe/print.go Normal file
Просмотреть файл

@ -0,0 +1,37 @@
package pipe
import (
"context"
"fmt"
"io"
)
func Print(a ...interface{}) Stage {
return Function(
"print",
func(_ context.Context, _ Env, _ io.Reader, stdout io.Writer) error {
_, err := fmt.Fprint(stdout, a...)
return err
},
)
}
func Println(a ...interface{}) Stage {
return Function(
"println",
func(_ context.Context, _ Env, _ io.Reader, stdout io.Writer) error {
_, err := fmt.Fprintln(stdout, a...)
return err
},
)
}
func Printf(format string, a ...interface{}) Stage {
return Function(
"printf",
func(_ context.Context, _ Env, _ io.Reader, stdout io.Writer) error {
_, err := fmt.Fprintf(stdout, format, a...)
return err
},
)
}

75
internal/pipe/scanner.go Normal file
Просмотреть файл

@ -0,0 +1,75 @@
package pipe
import (
"bufio"
"context"
"errors"
"io"
)
// Scanner defines the interface (which is implemented by
// `bufio.Scanner`) that is needed by `AddScannerFunction()`. See
// `bufio.Scanner` for how these methods should behave.
type Scanner interface {
Scan() bool
Bytes() []byte
Err() error
}
// FinishEarly is an error that can be returned by a
// `LinewiseStageFunc` to request that the iteration be ended early,
// without an error.
//nolint:revive
var FinishEarly = errors.New("finish stage early")
// NewScannerFunc is used to create a `Scanner` for scanning input
// that is coming from `r`.
type NewScannerFunc func(r io.Reader) (Scanner, error)
// ScannerFunction creates a function-based `Stage`. The function will
// be passed input, one line at a time, and may emit output. See the
// definition of `LinewiseStageFunc` for more information.
func ScannerFunction(
name string, newScanner NewScannerFunc, f LinewiseStageFunc,
) Stage {
stage := Function(
name,
func(ctx context.Context, env Env, stdin io.Reader, stdout io.Writer) (theErr error) {
scanner, err := newScanner(stdin)
if err != nil {
return err
}
var out *bufio.Writer
if stdout != nil {
out = bufio.NewWriter(stdout)
defer func() {
err := out.Flush()
if err != nil && theErr == nil {
// Note: this sets the named return value,
// thereby causing the whole stage to report
// the error.
theErr = err
}
}()
}
for scanner.Scan() {
if ctx.Err() != nil {
return ctx.Err()
}
err := f(ctx, env, scanner.Bytes(), out)
if err != nil {
return err
}
}
if err := scanner.Err(); err != nil {
return err
}
return nil
// `p.AddFunction()` arranges for `stdout` to be closed.
},
)
return IgnoreError(stage, IsError(FinishEarly))
}

34
internal/pipe/stage.go Normal file
Просмотреть файл

@ -0,0 +1,34 @@
package pipe
import (
"context"
"io"
)
// Stage is an element of a `Pipeline`.
type Stage interface {
// Name returns the name of the stage.
Name() string
// Start starts the stage in the background, in the environment
// described by `env`, and using `stdin` as input. (`stdin` should
// be set to `nil` if the stage is to receive no input, which
// might be the case for the first stage in a pipeline.) It
// returns an `io.ReadCloser` from which the stage's output can be
// read (or `nil` if it generates no output, which should only be
// the case for the last stage in a pipeline). It is the stages'
// responsibility to close `stdin` (if it is not nil) when it has
// read all of the input that it needs, and to close the write end
// of its output reader when it is done, as that is generally how
// the subsequent stage knows that it has received all of its
// input and can finish its work, too.
//
// If `Start()` returns without an error, `Wait()` must also be
// called, to allow all resources to be freed.
Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error)
// Wait waits for the stage to be done, either because it has
// finished or because it has been killed due to the expiration of
// the context passed to `Start()`.
Wait() error
}